use super::kernel;
use super::tensor::Tensor1D;
use super::weights::{WeightTensor, Weights};
use crate::backends::llm_policy::OptimizerKind;
use anyhow::{Context, Result, bail};
use serde_json::json;
use std::fs::File;
use std::io::{BufWriter, Write};
use std::path::Path;
use wide::f32x8;
13
/// Hyperparameters describing a Mamba model architecture.
#[derive(Debug, Clone)]
pub struct Config {
    /// Number of entries in the token vocabulary (also the logit width).
    pub vocab_size: usize,
    /// Residual-stream (embedding) dimension.
    pub hidden_size: usize,
    /// Number of stacked Mamba blocks.
    pub num_layers: usize,
    /// Expanded inner dimension used inside each mixer.
    pub inner_size: usize,
    /// SSM state dimension per inner channel.
    pub state_size: usize,
    /// Width of the depthwise causal conv kernel.
    pub conv_kernel: usize,
    /// Rank of the low-rank dt projection.
    pub dt_rank: usize,
    /// Epsilon used for numerical stability in layer norm.
    pub layer_norm_eps: f32,
}
34
35impl Default for Config {
36 fn default() -> Self {
37 Self {
38 vocab_size: 256,
39 hidden_size: 256,
40 num_layers: 6,
41 inner_size: 512,
42 state_size: 16,
43 conv_kernel: 4,
44 dt_rank: 16,
45 layer_norm_eps: 1e-5,
46 }
47 }
48}
49
50impl Config {
51 pub fn validate(&self) -> Result<()> {
53 if self.vocab_size == 0 {
54 bail!("mamba vocab_size must be > 0");
55 }
56 if self.hidden_size == 0 {
57 bail!("mamba hidden_size must be > 0");
58 }
59 if self.num_layers == 0 {
60 bail!("mamba num_layers must be > 0");
61 }
62 if self.inner_size == 0 {
63 bail!("mamba inner_size must be > 0");
64 }
65 if self.state_size == 0 {
66 bail!("mamba state_size must be > 0");
67 }
68 if self.conv_kernel == 0 {
69 bail!("mamba conv_kernel must be > 0");
70 }
71 if self.dt_rank == 0 {
72 bail!("mamba dt_rank must be > 0");
73 }
74 Ok(())
75 }
76}
77
78#[derive(Clone)]
79struct LayerState {
80 conv: Tensor1D, conv_pos: usize,
82 ssm: Tensor1D, }
84
85impl LayerState {
86 fn new(cfg: &Config) -> Self {
87 Self {
88 conv: Tensor1D::zeros(cfg.inner_size * cfg.conv_kernel),
89 conv_pos: 0,
90 ssm: Tensor1D::zeros(cfg.inner_size * cfg.state_size),
91 }
92 }
93
94 fn reset(&mut self) {
95 self.conv.zero();
96 self.conv_pos = 0;
97 self.ssm.zero();
98 }
99}
100
/// Whole-model recurrent state: one `LayerState` per layer.
#[derive(Clone)]
pub struct State {
    layers: Vec<LayerState>,
}
106
107impl State {
108 pub fn new(cfg: &Config) -> Self {
110 Self {
111 layers: (0..cfg.num_layers).map(|_| LayerState::new(cfg)).collect(),
112 }
113 }
114
115 pub fn reset(&mut self) {
117 for l in &mut self.layers {
118 l.reset();
119 }
120 }
121}
122
123#[derive(Clone)]
124struct LayerWeights {
125 norm_w: Tensor1D,
126 norm_b: Option<Tensor1D>,
127
128 in_proj_w: Tensor1D, in_proj_b: Option<Tensor1D>,
130
131 conv_w: Tensor1D, conv_b: Option<Tensor1D>,
133
134 x_proj_w: Tensor1D, x_proj_b: Option<Tensor1D>,
136
137 dt_proj_w: Tensor1D, dt_proj_b: Tensor1D,
139
140 a_log: Tensor1D, a: Tensor1D, d: Tensor1D, out_proj_w: Tensor1D, out_proj_b: Option<Tensor1D>,
147}
148
/// First (`m`) and second (`v`) Adam moment buffers for one parameter tensor.
#[derive(Clone)]
struct AdamTensorState {
    m: Tensor1D,
    v: Tensor1D,
}
154
155impl AdamTensorState {
156 #[inline]
157 fn new(len: usize) -> Self {
158 Self {
159 m: Tensor1D::zeros(len),
160 v: Tensor1D::zeros(len),
161 }
162 }
163}
164
/// Adam moments for every trainable tensor of one layer. `Option` fields
/// mirror the corresponding optional biases in `LayerWeights`.
#[derive(Clone)]
struct LayerAdamState {
    norm_w: AdamTensorState,
    norm_b: Option<AdamTensorState>,
    in_proj_w: AdamTensorState,
    in_proj_b: Option<AdamTensorState>,
    conv_w: AdamTensorState,
    conv_b: Option<AdamTensorState>,
    x_proj_w: AdamTensorState,
    x_proj_b: Option<AdamTensorState>,
    dt_proj_w: AdamTensorState,
    dt_proj_b: AdamTensorState,
    // Moments for the SSM decay parameter (sized from `a_log`).
    a: AdamTensorState,
    d: AdamTensorState,
    out_proj_w: AdamTensorState,
    out_proj_b: Option<AdamTensorState>,
}
182
/// Adam moments for the whole model: embeddings, final norm, LM head, and
/// one `LayerAdamState` per layer.
#[derive(Clone)]
pub struct FullAdamState {
    embeddings: AdamTensorState,
    final_norm_w: AdamTensorState,
    final_norm_b: Option<AdamTensorState>,
    lm_head: AdamTensorState,
    lm_head_b: Option<AdamTensorState>,
    layers: Vec<LayerAdamState>,
}
193
/// Selects which parameter groups a training step is allowed to update.
/// `Default` yields an all-false mask (nothing trained).
#[derive(Clone, Copy, Debug, Default)]
pub struct TrainScopeMask {
    /// Token embedding table.
    pub embed: bool,
    /// Per-layer and final layer-norm parameters.
    pub layer_norm: bool,
    /// Depthwise conv weights and biases.
    pub mixer_conv: bool,
    /// SSM parameters (A_log and D).
    pub mixer_ssm: bool,
    /// in/x/dt/out projection weights and biases.
    pub mixer_proj: bool,
    /// LM head weights and bias.
    pub head: bool,
    /// Caller-supplied output bias (passed separately into the update;
    /// not one of the model's own tensors).
    pub bias: bool,
}
212
213impl TrainScopeMask {
214 #[inline]
215 pub fn all() -> Self {
217 Self {
218 embed: true,
219 layer_norm: true,
220 mixer_conv: true,
221 mixer_ssm: true,
222 mixer_proj: true,
223 head: true,
224 bias: true,
225 }
226 }
227
228 #[inline]
229 pub fn trains_model_params(&self) -> bool {
231 self.embed
232 || self.layer_norm
233 || self.mixer_conv
234 || self.mixer_ssm
235 || self.mixer_proj
236 || self.head
237 }
238}
239
/// Precomputed per-step Adam hyperparameters and bias-correction factors.
struct AdamStep {
    /// Learning rate.
    lr: f32,
    /// Gradient clip threshold (clamped to be non-negative at construction).
    clip: f32,
    /// Exponential decay rate for the first moment (beta1).
    b1: f32,
    /// Exponential decay rate for the second moment (beta2).
    b2: f32,
    /// Denominator epsilon.
    eps: f32,
    /// `1 - b1^t` bias-correction factor for step `t`.
    bias_corr1: f32,
    /// `1 - b2^t` bias-correction factor for step `t`.
    bias_corr2: f32,
}
249
/// Forward-pass intermediates recorded for one layer at one token; the
/// backward pass consumes them instead of recomputing the forward step.
/// Sizes in parentheses refer to `Config` dimensions (see `new`).
#[derive(Clone)]
struct LayerTrainTrace {
    /// Residual-stream input to the layer (hidden_size).
    h_in: Tensor1D,
    /// Layer-norm output (hidden_size).
    norm: Tensor1D,
    /// in_proj output: concatenated x and gate halves (2 * inner_size).
    xz: Tensor1D,
    /// Conv input values for this step (inner_size).
    conv_pre: Tensor1D,
    /// Conv output before activation (inner_size).
    conv_post: Tensor1D,
    /// Sigmoid of the conv output (inner_size) — NOTE(review): presumably
    /// part of a SiLU activation; confirm against the forward kernel.
    conv_sigmoid: Tensor1D,
    /// x_proj output: dt inputs plus B and C (dt_rank + 2 * state_size).
    proj: Tensor1D,
    /// dt values before their nonlinearity (inner_size).
    dt_raw: Tensor1D,
    /// Final per-channel step sizes (inner_size).
    dt: Tensor1D,
    /// Gate (z) values (inner_size).
    gate: Tensor1D,
    /// Sigmoid of the gate values (inner_size).
    gate_sigmoid: Tensor1D,
    /// SSM output before gating (inner_size).
    y_pre: Tensor1D,
    /// Gated SSM output (inner_size).
    y: Tensor1D,
    /// out_proj result added back to the residual stream (hidden_size).
    out: Tensor1D,
    /// Per-element discrete decay terms (inner_size * state_size) —
    /// TODO confirm exact definition against the forward kernel.
    d_a: Tensor1D,
    /// SSM state before this step's update (inner_size * state_size).
    ssm_prev: Tensor1D,
    /// Conv window contents before this step (inner_size * conv_kernel).
    conv_prev: Tensor1D,
    /// Conv window write position before this step.
    conv_pos_prev: usize,
}
271
272impl LayerTrainTrace {
273 fn new(cfg: &Config) -> Self {
274 Self {
275 h_in: Tensor1D::zeros(cfg.hidden_size),
276 norm: Tensor1D::zeros(cfg.hidden_size),
277 xz: Tensor1D::zeros(cfg.inner_size * 2),
278 conv_pre: Tensor1D::zeros(cfg.inner_size),
279 conv_post: Tensor1D::zeros(cfg.inner_size),
280 conv_sigmoid: Tensor1D::zeros(cfg.inner_size),
281 proj: Tensor1D::zeros(cfg.dt_rank + 2 * cfg.state_size),
282 dt_raw: Tensor1D::zeros(cfg.inner_size),
283 dt: Tensor1D::zeros(cfg.inner_size),
284 gate: Tensor1D::zeros(cfg.inner_size),
285 gate_sigmoid: Tensor1D::zeros(cfg.inner_size),
286 y_pre: Tensor1D::zeros(cfg.inner_size),
287 y: Tensor1D::zeros(cfg.inner_size),
288 out: Tensor1D::zeros(cfg.hidden_size),
289 d_a: Tensor1D::zeros(cfg.inner_size * cfg.state_size),
290 ssm_prev: Tensor1D::zeros(cfg.inner_size * cfg.state_size),
291 conv_prev: Tensor1D::zeros(cfg.inner_size * cfg.conv_kernel),
292 conv_pos_prev: 0,
293 }
294 }
295}
296
/// Snapshot of one token's forward pass: the token id, the final-norm
/// output, the final hidden state, and every layer's trace.
#[derive(Clone)]
struct TokenTrainTrace {
    /// Token id processed at this step.
    token: usize,
    /// Final layer-norm output (the LM-head input).
    norm: Tensor1D,
    /// Final hidden state captured at this step (`train_h_final`).
    h_final: Tensor1D,
    /// Per-layer intermediate activations.
    layers: Vec<LayerTrainTrace>,
}
304
305impl TokenTrainTrace {
306 fn from_scratch(scratch: &ScratchBuffers) -> Self {
307 Self {
308 token: scratch.train_token,
309 norm: scratch.norm.clone(),
310 h_final: scratch.train_h_final.clone(),
311 layers: scratch.train_trace_layers.clone(),
312 }
313 }
314}
315
/// Gradient accumulators for every trainable tensor of one layer; shapes
/// and optionality mirror `LayerWeights`.
#[derive(Clone)]
struct LayerGradState {
    norm_w: Tensor1D,
    norm_b: Option<Tensor1D>,
    in_proj_w: Tensor1D,
    in_proj_b: Option<Tensor1D>,
    conv_w: Tensor1D,
    conv_b: Option<Tensor1D>,
    x_proj_w: Tensor1D,
    x_proj_b: Option<Tensor1D>,
    dt_proj_w: Tensor1D,
    dt_proj_b: Tensor1D,
    // Gradient w.r.t. the SSM decay parameter.
    a: Tensor1D,
    d: Tensor1D,
    out_proj_w: Tensor1D,
    out_proj_b: Option<Tensor1D>,
}
333
/// Gradient accumulators for the whole model: embeddings, final norm,
/// LM head, and one `LayerGradState` per layer.
#[derive(Clone)]
struct FullGradState {
    embeddings: Tensor1D,
    final_norm_w: Tensor1D,
    final_norm_b: Option<Tensor1D>,
    lm_head: Tensor1D,
    lm_head_b: Option<Tensor1D>,
    layers: Vec<LayerGradState>,
}
343
/// Per-layer gradients flowing backward through time: the gradient arriving
/// at this layer's SSM state and conv window from the next time step.
#[derive(Clone)]
struct LayerRecurrentGradState {
    ssm_next: Tensor1D,
    conv_next: Tensor1D,
}
349
350impl LayerRecurrentGradState {
351 fn new(cfg: &Config) -> Self {
352 Self {
353 ssm_next: Tensor1D::zeros(cfg.inner_size * cfg.state_size),
354 conv_next: Tensor1D::zeros(cfg.inner_size * cfg.conv_kernel),
355 }
356 }
357}
358
/// Recurrent backward-pass gradients, one `LayerRecurrentGradState` per layer.
#[derive(Clone)]
struct RecurrentGradState {
    layers: Vec<LayerRecurrentGradState>,
}
363
364impl RecurrentGradState {
365 fn new(cfg: &Config) -> Self {
366 Self {
367 layers: (0..cfg.num_layers)
368 .map(|_| LayerRecurrentGradState::new(cfg))
369 .collect(),
370 }
371 }
372
373 fn zero(&mut self) {
374 for layer in &mut self.layers {
375 layer.ssm_next.zero();
376 layer.conv_next.zero();
377 }
378 }
379}
380
381#[derive(Clone)]
383pub struct Model {
384 cfg: Config,
385 embeddings: Tensor1D, final_norm_w: Tensor1D,
387 final_norm_b: Option<Tensor1D>,
388 lm_head: Tensor1D, lm_head_b: Option<Tensor1D>,
390 layers: Vec<LayerWeights>,
391}
392
/// Reusable forward/backward working buffers, allocated once per model
/// configuration so the per-token hot path does no allocation.
#[derive(Clone)]
pub struct ScratchBuffers {
    // --- forward activations ---
    h: Tensor1D,
    norm: Tensor1D,
    xz: Tensor1D,
    conv: Tensor1D,
    proj: Tensor1D,
    dt: Tensor1D,
    y: Tensor1D,
    out: Tensor1D,
    logits: Tensor1D,
    // --- backward-pass gradients (one per forward intermediate) ---
    grad_h: Tensor1D,
    grad_norm: Tensor1D,
    grad_xz: Tensor1D,
    grad_conv: Tensor1D,
    grad_conv_pre: Tensor1D,
    grad_proj: Tensor1D,
    grad_dt_raw: Tensor1D,
    grad_u: Tensor1D,
    grad_b: Tensor1D,
    grad_c: Tensor1D,
    grad_ssm_d: Tensor1D,
    grad_ssm_a: Tensor1D,
    grad_ssm_a_log: Tensor1D,
    grad_conv_w: Tensor1D,
    grad_conv_b: Tensor1D,
    grad_y: Tensor1D,
    grad_out: Tensor1D,
    grad_logits: Tensor1D,
    grad_residual: Tensor1D,
    // --- training-trace capture (consumed by TokenTrainTrace::from_scratch) ---
    train_trace_layers: Vec<LayerTrainTrace>,
    train_h_final: Tensor1D,
    train_token: usize,
    // True once a trace has been captured; cleared when capture is disabled.
    train_trace_valid: bool,
    // When set, forward passes record intermediates into the trace buffers.
    capture_train_trace: bool,
}
430
431impl ScratchBuffers {
432 pub fn new(cfg: &Config) -> Self {
434 let mut train_trace_layers = Vec::with_capacity(cfg.num_layers);
435 for _ in 0..cfg.num_layers {
436 train_trace_layers.push(LayerTrainTrace::new(cfg));
437 }
438 Self {
439 h: Tensor1D::zeros(cfg.hidden_size),
440 norm: Tensor1D::zeros(cfg.hidden_size),
441 xz: Tensor1D::zeros(cfg.inner_size * 2),
442 conv: Tensor1D::zeros(cfg.inner_size),
443 proj: Tensor1D::zeros(cfg.dt_rank + 2 * cfg.state_size),
444 dt: Tensor1D::zeros(cfg.inner_size),
445 y: Tensor1D::zeros(cfg.inner_size),
446 out: Tensor1D::zeros(cfg.hidden_size),
447 logits: Tensor1D::zeros(cfg.vocab_size),
448 grad_h: Tensor1D::zeros(cfg.hidden_size),
449 grad_norm: Tensor1D::zeros(cfg.hidden_size),
450 grad_xz: Tensor1D::zeros(cfg.inner_size * 2),
451 grad_conv: Tensor1D::zeros(cfg.inner_size),
452 grad_conv_pre: Tensor1D::zeros(cfg.inner_size),
453 grad_proj: Tensor1D::zeros(cfg.dt_rank + 2 * cfg.state_size),
454 grad_dt_raw: Tensor1D::zeros(cfg.inner_size),
455 grad_u: Tensor1D::zeros(cfg.dt_rank),
456 grad_b: Tensor1D::zeros(cfg.state_size),
457 grad_c: Tensor1D::zeros(cfg.state_size),
458 grad_ssm_d: Tensor1D::zeros(cfg.inner_size),
459 grad_ssm_a: Tensor1D::zeros(cfg.inner_size * cfg.state_size),
460 grad_ssm_a_log: Tensor1D::zeros(cfg.inner_size * cfg.state_size),
461 grad_conv_w: Tensor1D::zeros(cfg.inner_size * cfg.conv_kernel),
462 grad_conv_b: Tensor1D::zeros(cfg.inner_size),
463 grad_y: Tensor1D::zeros(cfg.inner_size),
464 grad_out: Tensor1D::zeros(cfg.hidden_size),
465 grad_logits: Tensor1D::zeros(cfg.vocab_size),
466 grad_residual: Tensor1D::zeros(cfg.hidden_size),
467 train_trace_layers,
468 train_h_final: Tensor1D::zeros(cfg.hidden_size),
469 train_token: 0,
470 train_trace_valid: false,
471 capture_train_trace: false,
472 }
473 }
474
475 #[inline]
477 pub fn lm_head_input(&self) -> &[f32] {
478 self.norm.as_slice()
479 }
480
481 #[inline]
483 pub fn logits(&self) -> &[f32] {
484 self.logits.as_slice()
485 }
486
487 #[inline]
489 pub fn set_lm_head_input(&mut self, value: &[f32]) {
490 self.norm.as_mut_slice().copy_from_slice(value);
491 }
492
493 #[inline]
495 pub fn set_capture_train_trace(&mut self, enabled: bool) {
496 self.capture_train_trace = enabled;
497 if !enabled {
498 self.train_trace_valid = false;
499 }
500 }
501
502 #[inline]
504 pub fn has_train_trace(&self) -> bool {
505 self.train_trace_valid
506 }
507}
508
509impl Model {
510 pub fn new_full_adam_state(&self) -> FullAdamState {
512 let mut layers = Vec::with_capacity(self.layers.len());
513 for layer in &self.layers {
514 layers.push(LayerAdamState {
515 norm_w: AdamTensorState::new(layer.norm_w.len()),
516 norm_b: layer.norm_b.as_ref().map(|b| AdamTensorState::new(b.len())),
517 in_proj_w: AdamTensorState::new(layer.in_proj_w.len()),
518 in_proj_b: layer
519 .in_proj_b
520 .as_ref()
521 .map(|b| AdamTensorState::new(b.len())),
522 conv_w: AdamTensorState::new(layer.conv_w.len()),
523 conv_b: layer.conv_b.as_ref().map(|b| AdamTensorState::new(b.len())),
524 x_proj_w: AdamTensorState::new(layer.x_proj_w.len()),
525 x_proj_b: layer
526 .x_proj_b
527 .as_ref()
528 .map(|b| AdamTensorState::new(b.len())),
529 dt_proj_w: AdamTensorState::new(layer.dt_proj_w.len()),
530 dt_proj_b: AdamTensorState::new(layer.dt_proj_b.len()),
531 a: AdamTensorState::new(layer.a_log.len()),
532 d: AdamTensorState::new(layer.d.len()),
533 out_proj_w: AdamTensorState::new(layer.out_proj_w.len()),
534 out_proj_b: layer
535 .out_proj_b
536 .as_ref()
537 .map(|b| AdamTensorState::new(b.len())),
538 });
539 }
540
541 FullAdamState {
542 embeddings: AdamTensorState::new(self.embeddings.len()),
543 final_norm_w: AdamTensorState::new(self.final_norm_w.len()),
544 final_norm_b: self
545 .final_norm_b
546 .as_ref()
547 .map(|b| AdamTensorState::new(b.len())),
548 lm_head: AdamTensorState::new(self.lm_head.len()),
549 lm_head_b: self
550 .lm_head_b
551 .as_ref()
552 .map(|b| AdamTensorState::new(b.len())),
553 layers,
554 }
555 }
556
557 pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
559 let weights = Weights::load(path.as_ref()).with_context(|| {
560 format!(
561 "failed to load model weights from {}",
562 path.as_ref().display()
563 )
564 })?;
565
566 if weights.get("backbone.embedding.weight").is_some() {
567 Self::load_official(&weights)
568 } else {
569 Self::load_native(&weights)
570 }
571 }
572
    /// Builds a randomly initialized model for `cfg`, deterministic in `seed`.
    ///
    /// Norm scales and `D` start at 1, biases at 0, the dt-projection bias
    /// at -2 (keeping initial step sizes small), `a_log` at 0 (so the cached
    /// `a` starts at -1), and all projection weights at small uniform values.
    ///
    /// # Errors
    /// Returns an error if `cfg` fails validation.
    pub fn new_random(cfg: Config, seed: u64) -> Result<Self> {
        cfg.validate()?;

        // Deterministic RNG: the same seed reproduces the same model.
        let mut rng = MambaRng::new(seed);
        let v = cfg.vocab_size;
        let h = cfg.hidden_size;
        let i = cfg.inner_size;
        let s = cfg.state_size;
        let k = cfg.conv_kernel;
        let r = cfg.dt_rank;

        let mut embeddings = Tensor1D::zeros(v * h);
        init_uniform(&mut embeddings, &mut rng, 0.02);

        // Layer-norm scales start at identity.
        let mut final_norm_w = Tensor1D::zeros(h);
        init_const(&mut final_norm_w, 1.0);

        let mut lm_head = Tensor1D::zeros(v * h);
        init_uniform(&mut lm_head, &mut rng, 0.02);

        let mut layers = Vec::with_capacity(cfg.num_layers);
        for _ in 0..cfg.num_layers {
            let mut norm_w = Tensor1D::zeros(h);
            init_const(&mut norm_w, 1.0);

            // Input projection emits the concatenated x/z halves (2 * inner).
            let mut in_proj_w = Tensor1D::zeros((2 * i) * h);
            init_uniform(&mut in_proj_w, &mut rng, 0.02);
            let mut in_proj_b = Tensor1D::zeros(2 * i);
            init_const(&mut in_proj_b, 0.0);

            let mut conv_w = Tensor1D::zeros(i * k);
            init_uniform(&mut conv_w, &mut rng, 0.05);
            let mut conv_b = Tensor1D::zeros(i);
            init_const(&mut conv_b, 0.0);

            // x_proj emits dt_rank dt-inputs plus B and C (state_size each).
            let mut x_proj_w = Tensor1D::zeros((r + 2 * s) * i);
            init_uniform(&mut x_proj_w, &mut rng, 0.02);

            let mut dt_proj_w = Tensor1D::zeros(i * r);
            init_uniform(&mut dt_proj_w, &mut rng, 0.02);
            // Negative bias keeps initial dt values small.
            let mut dt_proj_b = Tensor1D::zeros(i);
            init_const(&mut dt_proj_b, -2.0);

            // a_log = 0 gives a cached a = -exp(0) = -1 for every entry.
            let mut a_log = Tensor1D::zeros(i * s);
            init_const(&mut a_log, 0.0);
            let a = a_from_a_log_tensor(&a_log);
            let mut d = Tensor1D::zeros(i);
            init_const(&mut d, 1.0);

            let mut out_proj_w = Tensor1D::zeros(h * i);
            init_uniform(&mut out_proj_w, &mut rng, 0.02);
            let mut out_proj_b = Tensor1D::zeros(h);
            init_const(&mut out_proj_b, 0.0);

            layers.push(LayerWeights {
                norm_w,
                norm_b: None,
                in_proj_w,
                in_proj_b: Some(in_proj_b),
                conv_w,
                conv_b: Some(conv_b),
                x_proj_w,
                x_proj_b: None,
                dt_proj_w,
                dt_proj_b,
                a_log,
                a,
                d,
                out_proj_w,
                out_proj_b: Some(out_proj_b),
            });
        }

        Ok(Self {
            cfg,
            embeddings,
            final_norm_w,
            final_norm_b: None,
            lm_head,
            lm_head_b: None,
            layers,
        })
    }
657
658 pub fn save_safetensors<P: AsRef<Path>>(&self, path: P) -> Result<()> {
660 #[derive(Clone)]
661 struct TensorRec {
662 name: String,
663 shape: Vec<usize>,
664 data: Vec<f32>,
665 }
666
667 let c = self.cfg.hidden_size;
668 let v = self.cfg.vocab_size;
669 let i = self.cfg.inner_size;
670 let s = self.cfg.state_size;
671 let k = self.cfg.conv_kernel;
672 let r = self.cfg.dt_rank;
673
674 let mut recs: Vec<TensorRec> = Vec::new();
675
676 let mut push = |name: String, shape: Vec<usize>, t: &Tensor1D| {
677 recs.push(TensorRec {
678 name,
679 shape,
680 data: t.as_slice().to_vec(),
681 });
682 };
683
684 push(
685 "model.embeddings.weight".to_string(),
686 vec![v, c],
687 &self.embeddings,
688 );
689 push("model.norm.weight".to_string(), vec![c], &self.final_norm_w);
690 if let Some(b) = &self.final_norm_b {
691 push("model.norm.bias".to_string(), vec![c], b);
692 }
693 push("lm_head.weight".to_string(), vec![v, c], &self.lm_head);
694 if let Some(b) = &self.lm_head_b {
695 push("lm_head.bias".to_string(), vec![v], b);
696 }
697
698 for (idx, layer) in self.layers.iter().enumerate() {
699 let pfx = format!("model.layers.{idx}.mixer");
700 push(
701 format!("model.layers.{idx}.norm.weight"),
702 vec![c],
703 &layer.norm_w,
704 );
705 if let Some(b) = &layer.norm_b {
706 push(format!("model.layers.{idx}.norm.bias"), vec![c], b);
707 }
708 push(
709 format!("{pfx}.in_proj.weight"),
710 vec![2 * i, c],
711 &layer.in_proj_w,
712 );
713 if let Some(b) = &layer.in_proj_b {
714 push(format!("{pfx}.in_proj.bias"), vec![2 * i], b);
715 }
716 push(format!("{pfx}.conv1d.weight"), vec![i, 1, k], &layer.conv_w);
717 if let Some(b) = &layer.conv_b {
718 push(format!("{pfx}.conv1d.bias"), vec![i], b);
719 }
720 push(
721 format!("{pfx}.x_proj.weight"),
722 vec![r + 2 * s, i],
723 &layer.x_proj_w,
724 );
725 if let Some(b) = &layer.x_proj_b {
726 push(format!("{pfx}.x_proj.bias"), vec![r + 2 * s], b);
727 }
728 push(
729 format!("{pfx}.dt_proj.weight"),
730 vec![i, r],
731 &layer.dt_proj_w,
732 );
733 push(format!("{pfx}.dt_proj.bias"), vec![i], &layer.dt_proj_b);
734 push(format!("{pfx}.A_log"), vec![i, s], &layer.a_log);
735 push(format!("{pfx}.D"), vec![i], &layer.d);
736 push(
737 format!("{pfx}.out_proj.weight"),
738 vec![c, i],
739 &layer.out_proj_w,
740 );
741 if let Some(b) = &layer.out_proj_b {
742 push(format!("{pfx}.out_proj.bias"), vec![c], b);
743 }
744 }
745
746 recs.sort_by(|a, b| a.name.cmp(&b.name));
747
748 let mut offset = 0usize;
749 let mut header = serde_json::Map::new();
750 header.insert("__metadata__".to_string(), json!({}));
751 for rec in &recs {
752 let bytes = rec.data.len() * 4;
753 header.insert(
754 rec.name.clone(),
755 json!({
756 "dtype": "F32",
757 "shape": rec.shape,
758 "data_offsets": [offset, offset + bytes],
759 }),
760 );
761 offset += bytes;
762 }
763
764 let header_bytes = serde_json::to_vec(&header)?;
765 let mut f = File::create(path)?;
766 f.write_all(&(header_bytes.len() as u64).to_le_bytes())?;
767 f.write_all(&header_bytes)?;
768 for rec in &recs {
769 for v in &rec.data {
770 f.write_all(&v.to_le_bytes())?;
771 }
772 }
773 Ok(())
774 }
775
776 pub fn save_full_adam_safetensors<P: AsRef<Path>>(
778 &self,
779 adam: &FullAdamState,
780 path: P,
781 ) -> Result<()> {
782 #[derive(Clone)]
783 struct TensorRec {
784 name: String,
785 shape: Vec<usize>,
786 data: Vec<f32>,
787 }
788
789 let c = self.cfg.hidden_size;
790 let v = self.cfg.vocab_size;
791 let i = self.cfg.inner_size;
792 let s = self.cfg.state_size;
793 let k = self.cfg.conv_kernel;
794 let r = self.cfg.dt_rank;
795
796 let mut recs: Vec<TensorRec> = Vec::new();
797 let mut push_state = |name_prefix: &str, shape: Vec<usize>, st: &AdamTensorState| {
798 recs.push(TensorRec {
799 name: format!("{name_prefix}.m"),
800 shape: shape.clone(),
801 data: st.m.as_slice().to_vec(),
802 });
803 recs.push(TensorRec {
804 name: format!("{name_prefix}.v"),
805 shape,
806 data: st.v.as_slice().to_vec(),
807 });
808 };
809
810 push_state("opt.embeddings", vec![v, c], &adam.embeddings);
811 push_state("opt.final_norm.weight", vec![c], &adam.final_norm_w);
812 if let Some(b) = &adam.final_norm_b {
813 push_state("opt.final_norm.bias", vec![c], b);
814 }
815 push_state("opt.lm_head.weight", vec![v, c], &adam.lm_head);
816 if let Some(b) = &adam.lm_head_b {
817 push_state("opt.lm_head.bias", vec![v], b);
818 }
819
820 for (idx, layer) in adam.layers.iter().enumerate() {
821 let pfx = format!("opt.layers.{idx}");
822 push_state(&format!("{pfx}.norm.weight"), vec![c], &layer.norm_w);
823 if let Some(b) = &layer.norm_b {
824 push_state(&format!("{pfx}.norm.bias"), vec![c], b);
825 }
826 push_state(
827 &format!("{pfx}.in_proj.weight"),
828 vec![2 * i, c],
829 &layer.in_proj_w,
830 );
831 if let Some(b) = &layer.in_proj_b {
832 push_state(&format!("{pfx}.in_proj.bias"), vec![2 * i], b);
833 }
834 push_state(
835 &format!("{pfx}.conv1d.weight"),
836 vec![i, 1, k],
837 &layer.conv_w,
838 );
839 if let Some(b) = &layer.conv_b {
840 push_state(&format!("{pfx}.conv1d.bias"), vec![i], b);
841 }
842 push_state(
843 &format!("{pfx}.x_proj.weight"),
844 vec![r + 2 * s, i],
845 &layer.x_proj_w,
846 );
847 if let Some(b) = &layer.x_proj_b {
848 push_state(&format!("{pfx}.x_proj.bias"), vec![r + 2 * s], b);
849 }
850 push_state(
851 &format!("{pfx}.dt_proj.weight"),
852 vec![i, r],
853 &layer.dt_proj_w,
854 );
855 push_state(&format!("{pfx}.dt_proj.bias"), vec![i], &layer.dt_proj_b);
856 push_state(&format!("{pfx}.A_log"), vec![i, s], &layer.a);
857 push_state(&format!("{pfx}.D"), vec![i], &layer.d);
858 push_state(
859 &format!("{pfx}.out_proj.weight"),
860 vec![c, i],
861 &layer.out_proj_w,
862 );
863 if let Some(b) = &layer.out_proj_b {
864 push_state(&format!("{pfx}.out_proj.bias"), vec![c], b);
865 }
866 }
867
868 recs.sort_by(|a, b| a.name.cmp(&b.name));
869
870 let mut offset = 0usize;
871 let mut header = serde_json::Map::new();
872 header.insert("__metadata__".to_string(), json!({}));
873 for rec in &recs {
874 let bytes = rec.data.len() * 4;
875 header.insert(
876 rec.name.clone(),
877 json!({
878 "dtype": "F32",
879 "shape": rec.shape,
880 "data_offsets": [offset, offset + bytes],
881 }),
882 );
883 offset += bytes;
884 }
885
886 let header_bytes = serde_json::to_vec(&header)?;
887 let mut f = File::create(path)?;
888 f.write_all(&(header_bytes.len() as u64).to_le_bytes())?;
889 f.write_all(&header_bytes)?;
890 for rec in &recs {
891 for v in &rec.data {
892 f.write_all(&v.to_le_bytes())?;
893 }
894 }
895 Ok(())
896 }
897
    /// Restores Adam optimizer moments previously written by
    /// `save_full_adam_safetensors`, validating every tensor's length
    /// against the current model's shapes.
    ///
    /// # Errors
    /// Fails if the file cannot be read, a required `.m`/`.v` tensor is
    /// missing, or a tensor's length does not match the model.
    pub fn load_full_adam_safetensors<P: AsRef<Path>>(&self, path: P) -> Result<FullAdamState> {
        let weights = Weights::load(path.as_ref()).with_context(|| {
            format!(
                "failed to load optimizer moments from {}",
                path.as_ref().display()
            )
        })?;
        // Start from zeroed moments shaped like this model, then fill in.
        let mut adam = self.new_full_adam_state();

        // Loads the `.m`/`.v` pair for one parameter, checking lengths.
        let load_state = |name_prefix: &str, st: &mut AdamTensorState| -> Result<()> {
            let m_name = format!("{name_prefix}.m");
            let v_name = format!("{name_prefix}.v");
            let m_t = weights
                .require(&m_name)
                .with_context(|| format!("missing optimizer tensor '{m_name}'"))?;
            let v_t = weights
                .require(&v_name)
                .with_context(|| format!("missing optimizer tensor '{v_name}'"))?;
            if m_t.data().len() != st.m.len() {
                bail!(
                    "optimizer tensor '{}' len {} != expected {}",
                    m_name,
                    m_t.data().len(),
                    st.m.len()
                );
            }
            if v_t.data().len() != st.v.len() {
                bail!(
                    "optimizer tensor '{}' len {} != expected {}",
                    v_name,
                    v_t.data().len(),
                    st.v.len()
                );
            }
            st.m.as_mut_slice().copy_from_slice(m_t.data());
            st.v.as_mut_slice().copy_from_slice(v_t.data());
            Ok(())
        };

        load_state("opt.embeddings", &mut adam.embeddings)?;
        load_state("opt.final_norm.weight", &mut adam.final_norm_w)?;
        if let Some(st) = adam.final_norm_b.as_mut() {
            load_state("opt.final_norm.bias", st)?;
        }
        load_state("opt.lm_head.weight", &mut adam.lm_head)?;
        if let Some(st) = adam.lm_head_b.as_mut() {
            load_state("opt.lm_head.bias", st)?;
        }

        // Optional biases are only loaded when the model actually has them.
        for (idx, layer) in adam.layers.iter_mut().enumerate() {
            let pfx = format!("opt.layers.{idx}");
            load_state(&format!("{pfx}.norm.weight"), &mut layer.norm_w)?;
            if let Some(st) = layer.norm_b.as_mut() {
                load_state(&format!("{pfx}.norm.bias"), st)?;
            }
            load_state(&format!("{pfx}.in_proj.weight"), &mut layer.in_proj_w)?;
            if let Some(st) = layer.in_proj_b.as_mut() {
                load_state(&format!("{pfx}.in_proj.bias"), st)?;
            }
            load_state(&format!("{pfx}.conv1d.weight"), &mut layer.conv_w)?;
            if let Some(st) = layer.conv_b.as_mut() {
                load_state(&format!("{pfx}.conv1d.bias"), st)?;
            }
            load_state(&format!("{pfx}.x_proj.weight"), &mut layer.x_proj_w)?;
            if let Some(st) = layer.x_proj_b.as_mut() {
                load_state(&format!("{pfx}.x_proj.bias"), st)?;
            }
            load_state(&format!("{pfx}.dt_proj.weight"), &mut layer.dt_proj_w)?;
            load_state(&format!("{pfx}.dt_proj.bias"), &mut layer.dt_proj_b)?;
            load_state(&format!("{pfx}.A_log"), &mut layer.a)?;
            load_state(&format!("{pfx}.D"), &mut layer.d)?;
            load_state(&format!("{pfx}.out_proj.weight"), &mut layer.out_proj_w)?;
            if let Some(st) = layer.out_proj_b.as_mut() {
                load_state(&format!("{pfx}.out_proj.bias"), st)?;
            }
        }

        Ok(adam)
    }
978
    /// Returns this model's configuration.
    #[inline]
    pub fn config(&self) -> &Config {
        &self.cfg
    }

    /// Creates a fresh zeroed recurrent state sized for this model.
    #[inline]
    pub fn new_state(&self) -> State {
        State::new(&self.cfg)
    }

    /// Read-only view of the LM-head weights (vocab_size * hidden_size values).
    #[inline]
    pub fn lm_head_weights(&self) -> &[f32] {
        self.lm_head.as_slice()
    }

    /// Mutable view of the LM-head weights (vocab_size * hidden_size values).
    #[inline]
    pub fn lm_head_weights_mut(&mut self) -> &mut [f32] {
        self.lm_head.as_mut_slice()
    }
1002
1003 fn new_full_grad_state(&self) -> FullGradState {
1004 let mut layers = Vec::with_capacity(self.layers.len());
1005 for layer in &self.layers {
1006 layers.push(LayerGradState {
1007 norm_w: Tensor1D::zeros(layer.norm_w.len()),
1008 norm_b: layer.norm_b.as_ref().map(|b| Tensor1D::zeros(b.len())),
1009 in_proj_w: Tensor1D::zeros(layer.in_proj_w.len()),
1010 in_proj_b: layer.in_proj_b.as_ref().map(|b| Tensor1D::zeros(b.len())),
1011 conv_w: Tensor1D::zeros(layer.conv_w.len()),
1012 conv_b: layer.conv_b.as_ref().map(|b| Tensor1D::zeros(b.len())),
1013 x_proj_w: Tensor1D::zeros(layer.x_proj_w.len()),
1014 x_proj_b: layer.x_proj_b.as_ref().map(|b| Tensor1D::zeros(b.len())),
1015 dt_proj_w: Tensor1D::zeros(layer.dt_proj_w.len()),
1016 dt_proj_b: Tensor1D::zeros(layer.dt_proj_b.len()),
1017 a: Tensor1D::zeros(layer.a.len()),
1018 d: Tensor1D::zeros(layer.d.len()),
1019 out_proj_w: Tensor1D::zeros(layer.out_proj_w.len()),
1020 out_proj_b: layer.out_proj_b.as_ref().map(|b| Tensor1D::zeros(b.len())),
1021 });
1022 }
1023 FullGradState {
1024 embeddings: Tensor1D::zeros(self.embeddings.len()),
1025 final_norm_w: Tensor1D::zeros(self.final_norm_w.len()),
1026 final_norm_b: self.final_norm_b.as_ref().map(|b| Tensor1D::zeros(b.len())),
1027 lm_head: Tensor1D::zeros(self.lm_head.len()),
1028 lm_head_b: self.lm_head_b.as_ref().map(|b| Tensor1D::zeros(b.len())),
1029 layers,
1030 }
1031 }
1032
    /// Creates zeroed per-layer recurrent gradient carriers for the
    /// backward pass.
    fn new_recurrent_grad_state(&self) -> RecurrentGradState {
        RecurrentGradState::new(&self.cfg)
    }
1036
1037 #[allow(clippy::too_many_arguments)]
1038 fn apply_full_gradients(
1039 &mut self,
1040 grads: &FullGradState,
1041 scope: TrainScopeMask,
1042 optimizer: OptimizerKind,
1043 lr: f32,
1044 clip: f32,
1045 adam_t: &mut usize,
1046 model_adam: Option<&mut FullAdamState>,
1047 out_bias: Option<&mut [f32]>,
1048 out_bias_grad: Option<&[f32]>,
1049 out_bias_adam_m: Option<&mut [f32]>,
1050 out_bias_adam_v: Option<&mut [f32]>,
1051 ) -> Result<()> {
1052 let mut adam_step = None::<AdamStep>;
1053 let mut model_adam = model_adam;
1054 if matches!(optimizer, OptimizerKind::Adam) {
1055 *adam_t = adam_t.saturating_add(1);
1056 let t = (*adam_t).max(1) as i32;
1057 let b1 = 0.9f32;
1058 let b2 = 0.999f32;
1059 adam_step = Some(AdamStep {
1060 lr,
1061 clip: clip.max(0.0),
1062 b1,
1063 b2,
1064 eps: 1e-8,
1065 bias_corr1: 1.0 - b1.powi(t),
1066 bias_corr2: 1.0 - b2.powi(t),
1067 });
1068 if scope.trains_model_params() && model_adam.is_none() {
1069 bail!("mamba Adam full-training state is missing");
1070 }
1071 }
1072
1073 if scope.bias
1074 && let (Some(bias), Some(grad)) = (out_bias, out_bias_grad)
1075 {
1076 match optimizer {
1077 OptimizerKind::Sgd => sgd_vec_update(bias, grad, lr, clip),
1078 OptimizerKind::Adam => {
1079 let cfg = adam_step.as_ref().expect("adam cfg initialized");
1080 let Some(m) = out_bias_adam_m else {
1081 bail!("mamba Adam output-bias moments are missing");
1082 };
1083 let Some(v) = out_bias_adam_v else {
1084 bail!("mamba Adam output-bias moments are missing");
1085 };
1086 apply_adam_vec_update_raw(bias, grad, m, v, cfg);
1087 }
1088 }
1089 }
1090
1091 if scope.head {
1092 match optimizer {
1093 OptimizerKind::Sgd => {
1094 sgd_vec_update(
1095 self.lm_head.as_mut_slice(),
1096 grads.lm_head.as_slice(),
1097 lr,
1098 clip,
1099 );
1100 if let (Some(b), Some(gb)) = (self.lm_head_b.as_mut(), grads.lm_head_b.as_ref())
1101 {
1102 sgd_vec_update(b.as_mut_slice(), gb.as_slice(), lr, clip);
1103 }
1104 }
1105 OptimizerKind::Adam => {
1106 let cfg = adam_step.as_ref().expect("adam cfg initialized");
1107 let adam = model_adam.as_mut().expect("adam state exists");
1108 apply_adam_vec_update(
1109 self.lm_head.as_mut_slice(),
1110 grads.lm_head.as_slice(),
1111 &mut adam.lm_head,
1112 cfg,
1113 );
1114 if let (Some(b), Some(gb), Some(ab)) = (
1115 self.lm_head_b.as_mut(),
1116 grads.lm_head_b.as_ref(),
1117 adam.lm_head_b.as_mut(),
1118 ) {
1119 apply_adam_vec_update(b.as_mut_slice(), gb.as_slice(), ab, cfg);
1120 }
1121 }
1122 }
1123 }
1124
1125 if scope.layer_norm {
1126 match optimizer {
1127 OptimizerKind::Sgd => {
1128 sgd_vec_update(
1129 self.final_norm_w.as_mut_slice(),
1130 grads.final_norm_w.as_slice(),
1131 lr,
1132 clip,
1133 );
1134 if let (Some(b), Some(gb)) =
1135 (self.final_norm_b.as_mut(), grads.final_norm_b.as_ref())
1136 {
1137 sgd_vec_update(b.as_mut_slice(), gb.as_slice(), lr, clip);
1138 }
1139 }
1140 OptimizerKind::Adam => {
1141 let cfg = adam_step.as_ref().expect("adam cfg initialized");
1142 let adam = model_adam.as_mut().expect("adam state exists");
1143 apply_adam_vec_update(
1144 self.final_norm_w.as_mut_slice(),
1145 grads.final_norm_w.as_slice(),
1146 &mut adam.final_norm_w,
1147 cfg,
1148 );
1149 if let (Some(b), Some(gb), Some(ab)) = (
1150 self.final_norm_b.as_mut(),
1151 grads.final_norm_b.as_ref(),
1152 adam.final_norm_b.as_mut(),
1153 ) {
1154 apply_adam_vec_update(b.as_mut_slice(), gb.as_slice(), ab, cfg);
1155 }
1156 }
1157 }
1158 }
1159
1160 for layer_idx in 0..self.cfg.num_layers {
1161 let layer = &mut self.layers[layer_idx];
1162 let grad = &grads.layers[layer_idx];
1163 match optimizer {
1164 OptimizerKind::Sgd => {
1165 if scope.layer_norm {
1166 sgd_vec_update(
1167 layer.norm_w.as_mut_slice(),
1168 grad.norm_w.as_slice(),
1169 lr,
1170 clip,
1171 );
1172 if let (Some(b), Some(gb)) = (layer.norm_b.as_mut(), grad.norm_b.as_ref()) {
1173 sgd_vec_update(b.as_mut_slice(), gb.as_slice(), lr, clip);
1174 }
1175 }
1176 if scope.mixer_proj {
1177 sgd_vec_update(
1178 layer.in_proj_w.as_mut_slice(),
1179 grad.in_proj_w.as_slice(),
1180 lr,
1181 clip,
1182 );
1183 if let (Some(b), Some(gb)) =
1184 (layer.in_proj_b.as_mut(), grad.in_proj_b.as_ref())
1185 {
1186 sgd_vec_update(b.as_mut_slice(), gb.as_slice(), lr, clip);
1187 }
1188 sgd_vec_update(
1189 layer.x_proj_w.as_mut_slice(),
1190 grad.x_proj_w.as_slice(),
1191 lr,
1192 clip,
1193 );
1194 if let (Some(b), Some(gb)) =
1195 (layer.x_proj_b.as_mut(), grad.x_proj_b.as_ref())
1196 {
1197 sgd_vec_update(b.as_mut_slice(), gb.as_slice(), lr, clip);
1198 }
1199 sgd_vec_update(
1200 layer.dt_proj_w.as_mut_slice(),
1201 grad.dt_proj_w.as_slice(),
1202 lr,
1203 clip,
1204 );
1205 sgd_vec_update(
1206 layer.dt_proj_b.as_mut_slice(),
1207 grad.dt_proj_b.as_slice(),
1208 lr,
1209 clip,
1210 );
1211 sgd_vec_update(
1212 layer.out_proj_w.as_mut_slice(),
1213 grad.out_proj_w.as_slice(),
1214 lr,
1215 clip,
1216 );
1217 if let (Some(b), Some(gb)) =
1218 (layer.out_proj_b.as_mut(), grad.out_proj_b.as_ref())
1219 {
1220 sgd_vec_update(b.as_mut_slice(), gb.as_slice(), lr, clip);
1221 }
1222 }
1223 if scope.mixer_conv {
1224 sgd_vec_update(
1225 layer.conv_w.as_mut_slice(),
1226 grad.conv_w.as_slice(),
1227 lr,
1228 clip,
1229 );
1230 if let (Some(b), Some(gb)) = (layer.conv_b.as_mut(), grad.conv_b.as_ref()) {
1231 sgd_vec_update(b.as_mut_slice(), gb.as_slice(), lr, clip);
1232 }
1233 }
1234 if scope.mixer_ssm {
1235 sgd_vec_update(layer.d.as_mut_slice(), grad.d.as_slice(), lr, clip);
1236 for idx in 0..layer.a_log.len().min(grad.a.len()) {
1237 let mut g = grad.a[idx] * layer.a[idx];
1238 if clip > 0.0 {
1239 g = g.clamp(-clip, clip);
1240 }
1241 let new_log = layer.a_log[idx] + lr * g;
1242 layer.a_log[idx] = new_log;
1243 layer.a[idx] = -new_log.exp();
1244 }
1245 }
1246 }
1247 OptimizerKind::Adam => {
1248 let cfg = adam_step.as_ref().expect("adam cfg initialized");
1249 let adam =
1250 &mut model_adam.as_mut().expect("adam state exists").layers[layer_idx];
1251 if scope.layer_norm {
1252 apply_adam_vec_update(
1253 layer.norm_w.as_mut_slice(),
1254 grad.norm_w.as_slice(),
1255 &mut adam.norm_w,
1256 cfg,
1257 );
1258 if let (Some(b), Some(gb), Some(ab)) = (
1259 layer.norm_b.as_mut(),
1260 grad.norm_b.as_ref(),
1261 adam.norm_b.as_mut(),
1262 ) {
1263 apply_adam_vec_update(b.as_mut_slice(), gb.as_slice(), ab, cfg);
1264 }
1265 }
1266 if scope.mixer_proj {
1267 apply_adam_vec_update(
1268 layer.in_proj_w.as_mut_slice(),
1269 grad.in_proj_w.as_slice(),
1270 &mut adam.in_proj_w,
1271 cfg,
1272 );
1273 if let (Some(b), Some(gb), Some(ab)) = (
1274 layer.in_proj_b.as_mut(),
1275 grad.in_proj_b.as_ref(),
1276 adam.in_proj_b.as_mut(),
1277 ) {
1278 apply_adam_vec_update(b.as_mut_slice(), gb.as_slice(), ab, cfg);
1279 }
1280 apply_adam_vec_update(
1281 layer.x_proj_w.as_mut_slice(),
1282 grad.x_proj_w.as_slice(),
1283 &mut adam.x_proj_w,
1284 cfg,
1285 );
1286 if let (Some(b), Some(gb), Some(ab)) = (
1287 layer.x_proj_b.as_mut(),
1288 grad.x_proj_b.as_ref(),
1289 adam.x_proj_b.as_mut(),
1290 ) {
1291 apply_adam_vec_update(b.as_mut_slice(), gb.as_slice(), ab, cfg);
1292 }
1293 apply_adam_vec_update(
1294 layer.dt_proj_w.as_mut_slice(),
1295 grad.dt_proj_w.as_slice(),
1296 &mut adam.dt_proj_w,
1297 cfg,
1298 );
1299 apply_adam_vec_update(
1300 layer.dt_proj_b.as_mut_slice(),
1301 grad.dt_proj_b.as_slice(),
1302 &mut adam.dt_proj_b,
1303 cfg,
1304 );
1305 apply_adam_vec_update(
1306 layer.out_proj_w.as_mut_slice(),
1307 grad.out_proj_w.as_slice(),
1308 &mut adam.out_proj_w,
1309 cfg,
1310 );
1311 if let (Some(b), Some(gb), Some(ab)) = (
1312 layer.out_proj_b.as_mut(),
1313 grad.out_proj_b.as_ref(),
1314 adam.out_proj_b.as_mut(),
1315 ) {
1316 apply_adam_vec_update(b.as_mut_slice(), gb.as_slice(), ab, cfg);
1317 }
1318 }
1319 if scope.mixer_conv {
1320 apply_adam_vec_update(
1321 layer.conv_w.as_mut_slice(),
1322 grad.conv_w.as_slice(),
1323 &mut adam.conv_w,
1324 cfg,
1325 );
1326 if let (Some(b), Some(gb), Some(ab)) = (
1327 layer.conv_b.as_mut(),
1328 grad.conv_b.as_ref(),
1329 adam.conv_b.as_mut(),
1330 ) {
1331 apply_adam_vec_update(b.as_mut_slice(), gb.as_slice(), ab, cfg);
1332 }
1333 }
1334 if scope.mixer_ssm {
1335 apply_adam_vec_update(
1336 layer.d.as_mut_slice(),
1337 grad.d.as_slice(),
1338 &mut adam.d,
1339 cfg,
1340 );
1341 let mut grad_log = vec![0.0f32; grad.a.len().min(layer.a.len())];
1342 for idx in 0..grad_log.len() {
1343 grad_log[idx] = grad.a[idx] * layer.a[idx];
1344 }
1345 apply_adam_vec_update_and_sync_neg_exp(
1346 layer.a_log.as_mut_slice(),
1347 layer.a.as_mut_slice(),
1348 &grad_log,
1349 &mut adam.a,
1350 cfg,
1351 );
1352 }
1353 }
1354 }
1355 }
1356
1357 if scope.embed {
1358 match optimizer {
1359 OptimizerKind::Sgd => {
1360 sgd_vec_update(
1361 self.embeddings.as_mut_slice(),
1362 grads.embeddings.as_slice(),
1363 lr,
1364 clip,
1365 );
1366 }
1367 OptimizerKind::Adam => {
1368 let cfg = adam_step.as_ref().expect("adam cfg initialized");
1369 let adam = model_adam.as_mut().expect("adam state exists");
1370 apply_adam_vec_update(
1371 self.embeddings.as_mut_slice(),
1372 grads.embeddings.as_slice(),
1373 &mut adam.embeddings,
1374 cfg,
1375 );
1376 }
1377 }
1378 }
1379
1380 Ok(())
1381 }
1382
    /// Accumulates gradients for a single token step of truncated BPTT.
    ///
    /// Runs the full backward pass for one (token, target, pdf) step using the
    /// forward-pass `trace` captured by `forward_impl::<true>`, adding the
    /// resulting parameter gradients into `grads` (scaled by `grad_scale`).
    /// Recurrent gradients flowing backwards in time (SSM state and conv ring
    /// buffer) are carried between consecutive calls via `future`; callers must
    /// invoke this newest-step-first.
    ///
    /// * `state_new` — the model state *after* this step's forward pass (its
    ///   `ssm` holds the post-update state values read back here).
    /// * `symbol` / `pdf` — the target byte and the predicted distribution;
    ///   the loss gradient is the cross-entropy residual `target - p`.
    /// * `scope` — which parameter groups receive gradients.
    /// * `out_bias_grad` — optional accumulator for the external output bias.
    #[allow(clippy::too_many_arguments, clippy::needless_range_loop)]
    fn accumulate_token_step_gradients(
        &self,
        scratch: &mut ScratchBuffers,
        trace: &TokenTrainTrace,
        state_new: &State,
        symbol: u8,
        pdf: &[f64],
        grad_scale: f32,
        scope: TrainScopeMask,
        grads: &mut FullGradState,
        out_bias_grad: Option<&mut [f32]>,
        future: &mut RecurrentGradState,
    ) -> Result<()> {
        // Shorthands for the model dimensions used throughout.
        let c = self.cfg.hidden_size;
        let i = self.cfg.inner_size;
        let s = self.cfg.state_size;
        let r = self.cfg.dt_rank;
        let k = self.cfg.conv_kernel;
        // Only tokens covered by both the vocab and the supplied pdf matter.
        let v = self.cfg.vocab_size.min(pdf.len());
        if v == 0 {
            return Ok(());
        }

        // Cross-entropy gradient at the logits: (one-hot target - softmax p),
        // pre-scaled so per-step contributions average over the segment.
        scratch.grad_logits.zero();
        for tok in 0..v {
            let p = pdf[tok].clamp(1e-12, 1.0) as f32;
            let target = if tok == symbol as usize { 1.0 } else { 0.0 };
            scratch.grad_logits[tok] = (target - p) * grad_scale;
        }

        // External output bias sees the raw logit gradient directly.
        if scope.bias
            && let Some(bias_grad) = out_bias_grad
        {
            add_vec_grad(&mut bias_grad[0..v], &scratch.grad_logits.as_slice()[0..v]);
        }

        scratch.grad_h.zero();
        if scope.head {
            // lm_head weight grad = grad_logits ⊗ normalized hidden state.
            add_outer_grad(
                grads.lm_head.as_mut_slice(),
                v,
                c,
                &scratch.grad_logits.as_slice()[0..v],
                trace.norm.as_slice(),
            );
            if let Some(lm_head_b) = grads.lm_head_b.as_mut() {
                let n = v.min(lm_head_b.len());
                add_vec_grad(
                    &mut lm_head_b.as_mut_slice()[0..n],
                    &scratch.grad_logits.as_slice()[0..n],
                );
            }
        }
        // Backprop through the head matmul: grad_h = lm_head^T * grad_logits.
        for tok in 0..v {
            let g = scratch.grad_logits[tok];
            if g == 0.0 {
                continue;
            }
            let row_off = tok * c;
            for col in 0..c {
                scratch.grad_h[col] += self.lm_head[row_off + col] * g;
            }
        }

        // If only head/bias are being trained there is nothing further to do.
        let needs_backprop = scope.embed
            || scope.layer_norm
            || scope.mixer_conv
            || scope.mixer_ssm
            || scope.mixer_proj;
        if !needs_backprop {
            return Ok(());
        }

        // Final RMS norm backward; grad_norm receives d/d(input), grad_out the
        // per-element weight gradient (matching how they are consumed below).
        rms_norm_backward(
            trace.h_final.as_slice(),
            self.final_norm_w.as_slice(),
            scratch.grad_h.as_slice(),
            self.cfg.layer_norm_eps,
            scratch.grad_norm.as_mut_slice(),
            scratch.grad_out.as_mut_slice(),
        );
        if scope.layer_norm {
            add_vec_grad(
                grads.final_norm_w.as_mut_slice(),
                scratch.grad_out.as_slice(),
            );
            // The norm bias gradient is just the upstream gradient.
            if let Some(final_norm_b) = grads.final_norm_b.as_mut() {
                add_vec_grad(final_norm_b.as_mut_slice(), scratch.grad_h.as_slice());
            }
        }
        scratch
            .grad_h
            .as_mut_slice()
            .copy_from_slice(scratch.grad_norm.as_slice());

        // Walk layers in reverse, threading grad_h (the residual-stream
        // gradient) back through each mixer block.
        for layer_idx in (0..self.cfg.num_layers).rev() {
            let tr = &trace.layers[layer_idx];
            let st_new = &state_new.layers[layer_idx];
            let layer = &self.layers[layer_idx];
            let layer_grads = &mut grads.layers[layer_idx];
            let future_layer = &mut future.layers[layer_idx];

            // grad_out = gradient at the layer's output projection result.
            scratch
                .grad_out
                .as_mut_slice()
                .copy_from_slice(scratch.grad_h.as_slice());

            // grad_y = out_proj_w^T * grad_out.
            unsafe {
                kernel::gemv_t(
                    layer.out_proj_w.as_ptr(),
                    scratch.grad_out.as_ptr(),
                    scratch.grad_y.as_mut_ptr(),
                    c,
                    i,
                );
            }
            if scope.mixer_proj {
                add_outer_grad(
                    layer_grads.out_proj_w.as_mut_slice(),
                    c,
                    i,
                    scratch.grad_out.as_slice(),
                    tr.y.as_slice(),
                );
                if let Some(out_proj_b) = layer_grads.out_proj_b.as_mut() {
                    add_vec_grad(out_proj_b.as_mut_slice(), scratch.grad_out.as_slice());
                }
            }

            // Save the residual-branch gradient; it is re-added after the
            // norm backward at the bottom of the loop.
            scratch
                .grad_residual
                .as_mut_slice()
                .copy_from_slice(scratch.grad_out.as_slice());
            scratch.grad_xz.zero();
            scratch.grad_b.zero();
            scratch.grad_c.zero();
            scratch.grad_ssm_d.zero();
            scratch.grad_ssm_a.zero();
            scratch.grad_dt_raw.zero();
            scratch.grad_conv.zero();
            scratch.grad_conv_pre.zero();
            scratch.grad_conv_w.zero();
            scratch.grad_conv_b.zero();

            // Per-channel backward through gate, SSM recurrence, and skip (D).
            for ch in 0..i {
                let g_y = scratch.grad_y[ch];
                // y = y_pre * silu(gate_pre): split into the two factors.
                let g_y_pre = g_y * tr.gate[ch];
                let g_gate = g_y * tr.y_pre[ch];
                scratch.grad_xz[i + ch] =
                    g_gate * silu_grad_from_sigmoid(tr.xz[i + ch], tr.gate_sigmoid[ch]);

                let conv = tr.conv_post[ch];
                let dt = tr.dt[ch];
                let xdt = conv * dt;
                let mut g_xdt = 0.0f32;
                let mut g_dt = 0.0f32;

                // Skip connection y_pre += D[ch] * conv.
                scratch.grad_conv[ch] = g_y_pre * layer.d[ch];
                if scope.mixer_ssm {
                    scratch.grad_ssm_d[ch] = g_y_pre * conv;
                }

                // SSM recurrence: s_new = s_prev * exp(dt*A) + (x*dt) * B;
                // y_pre += s_new * C. future_layer.ssm_next carries the
                // gradient arriving from the *later* time step.
                let row = ch * s;
                for j in 0..s {
                    let idx = row + j;
                    let c_j = tr.proj[r + s + j];
                    let b_j = tr.proj[r + j];
                    let s_prev = tr.ssm_prev[idx];
                    let s_new = st_new.ssm[idx];
                    let a_ij = layer.a[idx];
                    let d_a = tr.d_a[idx];

                    let g_ssm_new = g_y_pre * c_j + future_layer.ssm_next[idx];
                    scratch.grad_c[j] += g_y_pre * s_new;
                    g_xdt += g_ssm_new * b_j;
                    scratch.grad_b[j] += g_ssm_new * xdt;

                    // d(exp(dt*a))/d(dt) = d_a * a; d/d(a) = d_a * dt.
                    let g_d_a = g_ssm_new * s_prev;
                    g_dt += g_d_a * d_a * a_ij;
                    if scope.mixer_ssm {
                        scratch.grad_ssm_a[idx] += g_d_a * d_a * dt;
                    }
                    // Carry gradient to the previous time step's state.
                    future_layer.ssm_next[idx] = g_ssm_new * d_a;
                }

                // xdt = conv * dt feeds both B-injection paths.
                scratch.grad_conv[ch] += g_xdt * dt;
                g_dt += g_xdt * conv;
                // dt = softplus(dt_pre); softplus' = sigmoid.
                let dt_pre = tr.dt_raw[ch] + layer.dt_proj_b[ch];
                scratch.grad_dt_raw[ch] = g_dt * sigmoid(dt_pre);
            }

            if scope.mixer_ssm {
                add_vec_grad(layer_grads.d.as_mut_slice(), scratch.grad_ssm_d.as_slice());
                add_vec_grad(layer_grads.a.as_mut_slice(), scratch.grad_ssm_a.as_slice());
            }

            // Backprop dt projection: grad_u = dt_proj_w^T * grad_dt_raw.
            unsafe {
                kernel::gemv_t(
                    layer.dt_proj_w.as_ptr(),
                    scratch.grad_dt_raw.as_ptr(),
                    scratch.grad_u.as_mut_ptr(),
                    i,
                    r,
                );
            }
            if scope.mixer_proj {
                add_outer_grad(
                    layer_grads.dt_proj_w.as_mut_slice(),
                    i,
                    r,
                    scratch.grad_dt_raw.as_slice(),
                    &tr.proj.as_slice()[0..r],
                );
                add_vec_grad(
                    layer_grads.dt_proj_b.as_mut_slice(),
                    scratch.grad_dt_raw.as_slice(),
                );
            }

            // Reassemble the x_proj output gradient: [dt-rank | B | C].
            for kk in 0..r {
                scratch.grad_proj[kk] = scratch.grad_u[kk];
            }
            for j in 0..s {
                scratch.grad_proj[r + j] = scratch.grad_b[j];
                scratch.grad_proj[r + s + j] = scratch.grad_c[j];
            }

            // Backprop x_proj into the post-conv activation, merged with the
            // direct grad_conv contribution accumulated above.
            unsafe {
                kernel::gemv_t(
                    layer.x_proj_w.as_ptr(),
                    scratch.grad_proj.as_ptr(),
                    scratch.grad_conv_pre.as_mut_ptr(),
                    r + 2 * s,
                    i,
                );
                kernel::add_inplace(
                    scratch.grad_conv.as_mut_ptr(),
                    scratch.grad_conv_pre.as_ptr(),
                    i,
                );
            }
            if scope.mixer_proj {
                add_outer_grad(
                    layer_grads.x_proj_w.as_mut_slice(),
                    r + 2 * s,
                    i,
                    scratch.grad_proj.as_slice(),
                    tr.conv_post.as_slice(),
                );
                if let Some(x_proj_b) = layer_grads.x_proj_b.as_mut() {
                    add_vec_grad(x_proj_b.as_mut_slice(), scratch.grad_proj.as_slice());
                }
            }

            // SiLU backward over the conv output (grad_conv_pre is reused as
            // the destination now that its gemv_t result has been folded in).
            for ch in 0..i {
                scratch.grad_conv_pre[ch] = scratch.grad_conv[ch]
                    * silu_grad_from_sigmoid(tr.conv_pre[ch], tr.conv_sigmoid[ch]);
            }

            // Depthwise causal conv backward over the ring buffer. Tap 0 hits
            // the current input xz[ch]; taps 1..k hit progressively older ring
            // slots. conv_next carries gradients for inputs of earlier steps.
            for ch in 0..i {
                let g = scratch.grad_conv_pre[ch];
                let base = ch * k;
                let conv_future = &mut future_layer.conv_next.as_mut_slice()[base..base + k];
                let mut ring = tr.conv_pos_prev;

                scratch.grad_xz[ch] += g * layer.conv_w[base];
                // Gradient deposited by later steps for this step's input.
                scratch.grad_xz[ch] += conv_future[tr.conv_pos_prev];

                if scope.mixer_conv {
                    scratch.grad_conv_w[base] += g * tr.xz[ch];
                    if layer.conv_b.is_some() {
                        scratch.grad_conv_b[ch] += g;
                    }
                }

                // Consume this slot, then deposit gradients for older inputs.
                conv_future[tr.conv_pos_prev] = 0.0;
                for tap in 1..k {
                    ring = if ring == 0 { k - 1 } else { ring - 1 };
                    conv_future[ring] += g * layer.conv_w[base + tap];
                    if scope.mixer_conv {
                        scratch.grad_conv_w[base + tap] += g * tr.conv_prev[base + ring];
                    }
                }
            }
            if scope.mixer_conv {
                add_vec_grad(
                    layer_grads.conv_w.as_mut_slice(),
                    scratch.grad_conv_w.as_slice(),
                );
                if let Some(conv_b) = layer_grads.conv_b.as_mut() {
                    add_vec_grad(conv_b.as_mut_slice(), scratch.grad_conv_b.as_slice());
                }
            }

            // Backprop the input projection: grad_norm = in_proj_w^T * grad_xz.
            unsafe {
                kernel::gemv_t(
                    layer.in_proj_w.as_ptr(),
                    scratch.grad_xz.as_ptr(),
                    scratch.grad_norm.as_mut_ptr(),
                    2 * i,
                    c,
                );
            }
            if scope.mixer_proj {
                add_outer_grad(
                    layer_grads.in_proj_w.as_mut_slice(),
                    2 * i,
                    c,
                    scratch.grad_xz.as_slice(),
                    tr.norm.as_slice(),
                );
                if let Some(in_proj_b) = layer_grads.in_proj_b.as_mut() {
                    add_vec_grad(in_proj_b.as_mut_slice(), scratch.grad_xz.as_slice());
                }
            }

            // Layer RMS norm backward: grad_h gets d/d(input), grad_out the
            // per-element weight gradient.
            rms_norm_backward(
                tr.h_in.as_slice(),
                layer.norm_w.as_slice(),
                scratch.grad_norm.as_slice(),
                self.cfg.layer_norm_eps,
                scratch.grad_h.as_mut_slice(),
                scratch.grad_out.as_mut_slice(),
            );
            if scope.layer_norm {
                add_vec_grad(
                    layer_grads.norm_w.as_mut_slice(),
                    scratch.grad_out.as_slice(),
                );
                if let Some(norm_b) = layer_grads.norm_b.as_mut() {
                    add_vec_grad(norm_b.as_mut_slice(), scratch.grad_norm.as_slice());
                }
            }

            // Re-add the residual path saved at the top of the layer.
            for idx in 0..c {
                scratch.grad_h[idx] += scratch.grad_residual[idx];
            }
        }

        // Finally, the embedding row of the input token absorbs grad_h.
        if scope.embed {
            let tok = trace.token.min(self.cfg.vocab_size.saturating_sub(1));
            let row_off = tok * c;
            add_vec_grad(
                &mut grads.embeddings.as_mut_slice()[row_off..row_off + c],
                scratch.grad_h.as_slice(),
            );
        }

        Ok(())
    }
1734
1735 #[allow(clippy::too_many_arguments)]
1736 pub fn online_train_segment_tbptt(
1738 &mut self,
1739 scratch: &mut ScratchBuffers,
1740 start_state: &State,
1741 steps: &[(u32, u8, Vec<f64>)],
1742 scope: TrainScopeMask,
1743 optimizer: OptimizerKind,
1744 lr: f32,
1745 clip: f32,
1746 replay_chunk: usize,
1747 adam_t: &mut usize,
1748 model_adam: Option<&mut FullAdamState>,
1749 out_bias: Option<&mut [f32]>,
1750 out_bias_adam_m: Option<&mut [f32]>,
1751 out_bias_adam_v: Option<&mut [f32]>,
1752 live_state_out: &mut State,
1753 ) -> Result<()> {
1754 if steps.is_empty() {
1755 *live_state_out = start_state.clone();
1756 return Ok(());
1757 }
1758
1759 let grad_scale = 1.0f32 / (steps.len() as f32);
1760 let chunk = replay_chunk.max(1).min(steps.len().max(1));
1761 let mut grads = self.new_full_grad_state();
1762 let mut recurrent = self.new_recurrent_grad_state();
1763 recurrent.zero();
1764 let mut bias_grad = out_bias.as_deref().map(|b| vec![0.0f32; b.len()]);
1765
1766 {
1767 let mut checkpoints = Vec::<State>::new();
1768 let mut checkpoint_state = start_state.clone();
1769 scratch.set_capture_train_trace(false);
1770 for chunk_start in (0..steps.len()).step_by(chunk) {
1771 checkpoints.push(checkpoint_state.clone());
1772 let chunk_end = (chunk_start + chunk).min(steps.len());
1773 for (input_token, _, _) in &steps[chunk_start..chunk_end] {
1774 let _ = self.forward(scratch, *input_token, &mut checkpoint_state);
1775 }
1776 }
1777
1778 for chunk_idx in (0..checkpoints.len()).rev() {
1779 let chunk_start = chunk_idx * chunk;
1780 let chunk_end = (chunk_start + chunk).min(steps.len());
1781 let mut state = checkpoints[chunk_idx].clone();
1782 let mut step_states = Vec::<State>::with_capacity(chunk_end - chunk_start + 1);
1783 let mut step_traces =
1784 Vec::<TokenTrainTrace>::with_capacity(chunk_end - chunk_start);
1785 step_states.push(state.clone());
1786
1787 for (input_token, _, _) in &steps[chunk_start..chunk_end] {
1788 scratch.set_capture_train_trace(true);
1789 let _ = self.forward(scratch, *input_token, &mut state);
1790 step_traces.push(TokenTrainTrace::from_scratch(scratch));
1791 step_states.push(state.clone());
1792 }
1793
1794 for local_idx in (0..step_traces.len()).rev() {
1795 let (_, target_symbol, pdf) = &steps[chunk_start + local_idx];
1796 self.accumulate_token_step_gradients(
1797 scratch,
1798 &step_traces[local_idx],
1799 &step_states[local_idx + 1],
1800 *target_symbol,
1801 pdf,
1802 grad_scale,
1803 scope,
1804 &mut grads,
1805 bias_grad.as_deref_mut(),
1806 &mut recurrent,
1807 )?;
1808 }
1809 }
1810 }
1811
1812 self.apply_full_gradients(
1813 &grads,
1814 scope,
1815 optimizer,
1816 lr,
1817 clip,
1818 adam_t,
1819 model_adam,
1820 out_bias,
1821 bias_grad.as_deref(),
1822 out_bias_adam_m,
1823 out_bias_adam_v,
1824 )?;
1825
1826 scratch.set_capture_train_trace(false);
1827 *live_state_out = start_state.clone();
1828 for (input_token, _, _) in steps {
1829 let _ = self.forward(scratch, *input_token, live_state_out);
1830 }
1831 Ok(())
1832 }
1833
1834 #[inline(never)]
1836 pub fn forward<'a>(
1837 &'a self,
1838 scratch: &'a mut ScratchBuffers,
1839 token: u32,
1840 state: &mut State,
1841 ) -> &'a [f32] {
1842 if scratch.capture_train_trace {
1843 self.forward_impl::<true>(scratch, token, state)
1844 } else {
1845 self.forward_impl::<false>(scratch, token, state)
1846 }
1847 }
1848
    /// Single-token forward pass.
    ///
    /// Pipeline per layer: RMS norm → input projection (x and gate halves) →
    /// depthwise causal conv + SiLU → x-projection into [dt-rank | B | C] →
    /// dt projection + softplus → selective-scan state update → SiLU gating →
    /// output projection → residual add. Then a final RMS norm and the LM head
    /// produce logits into `scratch.logits`.
    ///
    /// When `CAPTURE` is true, every intermediate needed by the backward pass
    /// is copied into `scratch.train_trace_layers` / `train_h_final`.
    fn forward_impl<'a, const CAPTURE: bool>(
        &'a self,
        scratch: &'a mut ScratchBuffers,
        token: u32,
        state: &mut State,
    ) -> &'a [f32] {
        let c = self.cfg.hidden_size;
        let i = self.cfg.inner_size;
        let s = self.cfg.state_size;
        let r = self.cfg.dt_rank;

        // Clamp out-of-range tokens to the last vocab row instead of panicking.
        let token_idx = (token as usize).min(self.cfg.vocab_size.saturating_sub(1));
        let emb_off = token_idx * c;
        if CAPTURE {
            scratch.train_token = token_idx;
            scratch.train_trace_valid = true;
        } else {
            // A plain forward invalidates any previously captured trace.
            scratch.train_trace_valid = false;
        }
        // Residual stream starts as the token's embedding row.
        scratch
            .h
            .as_mut_slice()
            .copy_from_slice(&self.embeddings.as_slice()[emb_off..emb_off + c]);

        for layer_idx in 0..self.cfg.num_layers {
            let layer = &self.layers[layer_idx];
            let st = &mut state.layers[layer_idx];
            if CAPTURE {
                // Snapshot the pre-step inputs/state for the backward pass.
                let tr = &mut scratch.train_trace_layers[layer_idx];
                tr.h_in.as_mut_slice().copy_from_slice(scratch.h.as_slice());
                tr.ssm_prev
                    .as_mut_slice()
                    .copy_from_slice(st.ssm.as_slice());
                tr.conv_prev
                    .as_mut_slice()
                    .copy_from_slice(st.conv.as_slice());
                tr.conv_pos_prev = st.conv_pos;
            }

            // Pre-mixer RMS norm (optional bias).
            rms_norm(
                scratch.h.as_slice(),
                layer.norm_w.as_slice(),
                layer.norm_b.as_ref().map(Tensor1D::as_slice),
                self.cfg.layer_norm_eps,
                scratch.norm.as_mut_slice(),
            );
            if CAPTURE {
                let tr = &mut scratch.train_trace_layers[layer_idx];
                tr.norm
                    .as_mut_slice()
                    .copy_from_slice(scratch.norm.as_slice());
            }

            // Input projection: xz[0..i] is the conv/SSM path, xz[i..2i] the gate.
            unsafe {
                kernel::gemv(
                    layer.in_proj_w.as_ptr(),
                    scratch.norm.as_ptr(),
                    scratch.xz.as_mut_ptr(),
                    i * 2,
                    c,
                );
            }
            if let Some(bias) = &layer.in_proj_b {
                for (dst, &b) in scratch.xz.as_mut_slice().iter_mut().zip(bias.as_slice()) {
                    *dst += b;
                }
            }
            if CAPTURE {
                let tr = &mut scratch.train_trace_layers[layer_idx];
                tr.xz.as_mut_slice().copy_from_slice(scratch.xz.as_slice());
            }

            // Depthwise causal conv over the per-layer ring buffer in `st`.
            depthwise_conv_step(
                &scratch.xz.as_slice()[0..i],
                &layer.conv_w,
                layer.conv_b.as_ref(),
                self.cfg.conv_kernel,
                st,
                scratch.conv.as_mut_slice(),
            );
            if CAPTURE {
                let tr = &mut scratch.train_trace_layers[layer_idx];
                tr.conv_pre
                    .as_mut_slice()
                    .copy_from_slice(scratch.conv.as_slice());
            }

            // SiLU on the conv output; the capture path also stores the
            // sigmoid so the backward pass can reuse it.
            if CAPTURE {
                let tr = &mut scratch.train_trace_layers[layer_idx];
                for idx in 0..i {
                    let (post, sig) = silu_with_sigmoid(scratch.conv[idx]);
                    scratch.conv[idx] = post;
                    tr.conv_post[idx] = post;
                    tr.conv_sigmoid[idx] = sig;
                }
            } else {
                for idx in 0..i {
                    scratch.conv[idx] = silu(scratch.conv[idx]);
                }
            }

            // x-projection producing [dt-rank inputs | B (s) | C (s)].
            unsafe {
                kernel::gemv(
                    layer.x_proj_w.as_ptr(),
                    scratch.conv.as_ptr(),
                    scratch.proj.as_mut_ptr(),
                    r + 2 * s,
                    i,
                );
            }
            if let Some(bias) = &layer.x_proj_b {
                for (dst, &b) in scratch.proj.as_mut_slice().iter_mut().zip(bias.as_slice()) {
                    *dst += b;
                }
            }
            if CAPTURE {
                let tr = &mut scratch.train_trace_layers[layer_idx];
                tr.proj
                    .as_mut_slice()
                    .copy_from_slice(scratch.proj.as_slice());
            }

            // dt projection from the low-rank slice (bias added later,
            // per-channel, just before the softplus).
            unsafe {
                kernel::gemv(
                    layer.dt_proj_w.as_ptr(),
                    scratch.proj.as_ptr(),
                    scratch.dt.as_mut_ptr(),
                    i,
                    r,
                );
            }
            if CAPTURE {
                let tr = &mut scratch.train_trace_layers[layer_idx];
                tr.dt_raw
                    .as_mut_slice()
                    .copy_from_slice(scratch.dt.as_slice());
            }

            // Raw pointers let the scan loops index rows without per-element
            // bounds checks; all offsets stay within ch < i, j < s.
            let proj = scratch.proj.as_slice();
            let b_vec = &proj[r..r + s];
            let c_vec = &proj[r + s..r + 2 * s];
            let conv = scratch.conv.as_slice();
            let dt_raw = scratch.dt.as_slice();
            let xz = scratch.xz.as_slice();
            let d = layer.d.as_slice();
            let a = layer.a.as_slice();
            let dt_bias = layer.dt_proj_b.as_slice();
            let ssm = st.ssm.as_mut_slice();
            let b_ptr = b_vec.as_ptr();
            let c_ptr = c_vec.as_ptr();
            let a_ptr = a.as_ptr();
            let ssm_ptr = ssm.as_mut_ptr();

            if s == 16 {
                // Fast path: fixed state width handled by the specialized
                // selective_scan_state16 kernel.
                for ch in 0..i {
                    let x_ch = conv[ch];
                    let dt_pre = dt_raw[ch] + dt_bias[ch];
                    let gate_pre = xz[i + ch];
                    let dt = softplus(dt_pre);
                    let (gate, gate_sigmoid) = silu_with_sigmoid(gate_pre);
                    let x_dt = x_ch * dt;

                    let ssm_row_off = ch * s;
                    let row_a = unsafe { a_ptr.add(ssm_row_off) };
                    let row_ssm = unsafe { ssm_ptr.add(ssm_row_off) };
                    // Kernel writes the per-element decay exp(dt*A) here when
                    // capturing; null tells it to skip the store.
                    let trace_ptr = if CAPTURE {
                        unsafe {
                            scratch.train_trace_layers[layer_idx]
                                .d_a
                                .as_mut_ptr()
                                .add(ssm_row_off)
                        }
                    } else {
                        std::ptr::null_mut()
                    };
                    // Skip connection plus the scan contribution.
                    let mut y = d[ch] * x_ch;
                    y += unsafe {
                        selective_scan_state16::<CAPTURE>(
                            row_a, row_ssm, dt, x_dt, b_ptr, c_ptr, trace_ptr,
                        )
                    };
                    if CAPTURE {
                        let tr = &mut scratch.train_trace_layers[layer_idx];
                        tr.dt[ch] = dt;
                        tr.gate[ch] = gate;
                        tr.gate_sigmoid[ch] = gate_sigmoid;
                        tr.y_pre[ch] = y;
                    }
                    scratch.y[ch] = y * gate;
                    if CAPTURE {
                        scratch.train_trace_layers[layer_idx].y[ch] = scratch.y[ch];
                    }
                }
            } else {
                // Generic scalar path for arbitrary state sizes:
                // ssm[j] = ssm[j] * exp(dt*A[j]) + (x*dt) * B[j]; y += ssm[j]*C[j].
                for ch in 0..i {
                    let x_ch = conv[ch];
                    let dt_pre = dt_raw[ch] + dt_bias[ch];
                    let gate_pre = xz[i + ch];
                    let dt = softplus(dt_pre);
                    let (gate, gate_sigmoid) = silu_with_sigmoid(gate_pre);
                    let x_dt = x_ch * dt;

                    let mut y = d[ch] * x_ch;
                    let ssm_row_off = ch * s;
                    let row_a = unsafe { a_ptr.add(ssm_row_off) };
                    let row_ssm = unsafe { ssm_ptr.add(ssm_row_off) };
                    let mut j = 0usize;
                    while j < s {
                        let prev = unsafe { *row_ssm.add(j) };
                        let d_a = (dt * unsafe { *row_a.add(j) }).exp();
                        if CAPTURE {
                            scratch.train_trace_layers[layer_idx].d_a[ssm_row_off + j] = d_a;
                        }
                        let next = prev * d_a + x_dt * unsafe { *b_ptr.add(j) };
                        unsafe { *row_ssm.add(j) = next };
                        y += next * unsafe { *c_ptr.add(j) };
                        j += 1;
                    }
                    if CAPTURE {
                        let tr = &mut scratch.train_trace_layers[layer_idx];
                        tr.dt[ch] = dt;
                        tr.gate[ch] = gate;
                        tr.gate_sigmoid[ch] = gate_sigmoid;
                        tr.y_pre[ch] = y;
                    }
                    scratch.y[ch] = y * gate;
                    if CAPTURE {
                        scratch.train_trace_layers[layer_idx].y[ch] = scratch.y[ch];
                    }
                }
            }

            // Output projection back to the hidden size.
            unsafe {
                kernel::gemv(
                    layer.out_proj_w.as_ptr(),
                    scratch.y.as_ptr(),
                    scratch.out.as_mut_ptr(),
                    c,
                    i,
                );
            }
            if let Some(bias) = &layer.out_proj_b {
                for (dst, &b) in scratch.out.as_mut_slice().iter_mut().zip(bias.as_slice()) {
                    *dst += b;
                }
            }
            if CAPTURE {
                let tr = &mut scratch.train_trace_layers[layer_idx];
                tr.out
                    .as_mut_slice()
                    .copy_from_slice(scratch.out.as_slice());
            }

            // Residual add into the hidden stream.
            unsafe {
                kernel::add_inplace(scratch.h.as_mut_ptr(), scratch.out.as_ptr(), c);
            }
        }
        if CAPTURE {
            scratch
                .train_h_final
                .as_mut_slice()
                .copy_from_slice(scratch.h.as_slice());
        }

        // Final norm, then LM head into the logits buffer.
        rms_norm(
            scratch.h.as_slice(),
            self.final_norm_w.as_slice(),
            self.final_norm_b.as_ref().map(Tensor1D::as_slice),
            self.cfg.layer_norm_eps,
            scratch.norm.as_mut_slice(),
        );

        unsafe {
            kernel::gemv(
                self.lm_head.as_ptr(),
                scratch.norm.as_ptr(),
                scratch.logits.as_mut_ptr(),
                self.cfg.vocab_size,
                c,
            );
        }
        if let Some(bias) = &self.lm_head_b {
            for (dst, &b) in scratch
                .logits
                .as_mut_slice()
                .iter_mut()
                .zip(bias.as_slice())
            {
                *dst += b;
            }
        }

        scratch.logits.as_slice()
    }
2149
2150 #[allow(clippy::too_many_arguments)]
2155 #[allow(clippy::needless_range_loop)]
2156 pub fn online_train_step_bptt1(
2157 &mut self,
2158 scratch: &mut ScratchBuffers,
2159 state: &State,
2160 symbol: u8,
2161 pdf: &[f64],
2162 scope: TrainScopeMask,
2163 optimizer: OptimizerKind,
2164 lr: f32,
2165 clip: f32,
2166 adam_t: &mut usize,
2167 model_adam: Option<&mut FullAdamState>,
2168 out_bias: Option<&mut [f32]>,
2169 out_bias_adam_m: Option<&mut [f32]>,
2170 out_bias_adam_v: Option<&mut [f32]>,
2171 ) -> Result<()> {
2172 if !scope.trains_model_params() && !scope.bias {
2173 return Ok(());
2174 }
2175 let needs_backprop = scope.embed
2176 || scope.layer_norm
2177 || scope.mixer_conv
2178 || scope.mixer_ssm
2179 || scope.mixer_proj;
2180 if needs_backprop && !scratch.train_trace_valid {
2181 bail!("mamba full training trace is missing; run one forward step first");
2182 }
2183 let c = self.cfg.hidden_size;
2184 let i = self.cfg.inner_size;
2185 let s = self.cfg.state_size;
2186 let r = self.cfg.dt_rank;
2187 let v = self.cfg.vocab_size.min(pdf.len());
2188 if v == 0 {
2189 return Ok(());
2190 }
2191
2192 let mut adam_cfg = None::<AdamStep>;
2193 let mut model_adam = model_adam;
2194 if matches!(optimizer, OptimizerKind::Adam) {
2195 *adam_t = adam_t.saturating_add(1);
2196 let t = (*adam_t).max(1) as i32;
2197 let b1 = 0.9f32;
2198 let b2 = 0.999f32;
2199 adam_cfg = Some(AdamStep {
2200 lr,
2201 clip,
2202 b1,
2203 b2,
2204 eps: 1e-8,
2205 bias_corr1: 1.0 - b1.powi(t),
2206 bias_corr2: 1.0 - b2.powi(t),
2207 });
2208 if scope.trains_model_params() && model_adam.is_none() {
2209 bail!("mamba Adam full-training state is missing");
2210 }
2211 }
2212
2213 scratch.grad_logits.zero();
2215 for tok in 0..v {
2216 let p = pdf[tok].clamp(1e-12, 1.0) as f32;
2217 let target = if tok == symbol as usize { 1.0 } else { 0.0 };
2218 let mut g = target - p;
2219 if clip > 0.0 {
2220 g = g.clamp(-clip, clip);
2221 }
2222 scratch.grad_logits[tok] = g;
2223 }
2224
2225 if scope.bias
2226 && let Some(bias) = out_bias
2227 {
2228 let n = bias.len().min(v);
2229 let grad = &scratch.grad_logits.as_slice()[..n];
2230 match optimizer {
2231 OptimizerKind::Sgd => {
2232 for idx in 0..n {
2233 bias[idx] += lr * grad[idx];
2234 }
2235 }
2236 OptimizerKind::Adam => {
2237 let cfg = adam_cfg.as_ref().expect("adam cfg initialized");
2238 let Some(m) = out_bias_adam_m else {
2239 bail!("mamba Adam output-bias moments are missing");
2240 };
2241 let Some(vv) = out_bias_adam_v else {
2242 bail!("mamba Adam output-bias moments are missing");
2243 };
2244 if m.len() < n || vv.len() < n {
2245 bail!("mamba Adam output-bias moments have invalid shape");
2246 }
2247 for idx in 0..n {
2248 let g = grad[idx];
2249 m[idx] = cfg.b1 * m[idx] + (1.0 - cfg.b1) * g;
2250 vv[idx] = cfg.b2 * vv[idx] + (1.0 - cfg.b2) * g * g;
2251 let m_hat = m[idx] / cfg.bias_corr1;
2252 let v_hat = vv[idx] / cfg.bias_corr2;
2253 bias[idx] += cfg.lr * m_hat / (v_hat.sqrt() + cfg.eps);
2254 }
2255 }
2256 }
2257 }
2258
2259 scratch.grad_h.zero();
2262 let norm_in = scratch.norm.as_slice();
2263 if scope.head {
2264 match optimizer {
2265 OptimizerKind::Sgd => {
2266 let head = self.lm_head.as_mut_slice();
2267 let norm_ptr = norm_in.as_ptr();
2268 let grad_h_ptr = scratch.grad_h.as_mut_slice().as_mut_ptr();
2269 for tok in 0..v {
2270 let g = scratch.grad_logits[tok];
2271 let row_off = tok * c;
2272 let mut j = 0usize;
2273 unsafe {
2274 let g8 = f32x8::splat(g);
2275 let lr8 = f32x8::splat(lr);
2276 while j + 8 <= c {
2277 let idx = row_off + j;
2278 let wv = head.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
2279 let nv = norm_ptr.add(j).cast::<f32x8>().read_unaligned();
2280 let ghv = grad_h_ptr.add(j).cast::<f32x8>().read_unaligned();
2281 grad_h_ptr
2282 .add(j)
2283 .cast::<f32x8>()
2284 .write_unaligned(ghv + wv * g8);
2285 head.as_mut_ptr()
2286 .add(idx)
2287 .cast::<f32x8>()
2288 .write_unaligned(wv + (g8 * nv) * lr8);
2289 j += 8;
2290 }
2291 }
2292 while j < c {
2293 let idx = row_off + j;
2294 let w_old = head[idx];
2295 scratch.grad_h[j] += w_old * g;
2296 head[idx] = w_old + lr * g * norm_in[j];
2297 j += 1;
2298 }
2299 }
2300 if let Some(b) = self.lm_head_b.as_mut() {
2301 for tok in 0..v.min(b.len()) {
2302 b[tok] += lr * scratch.grad_logits[tok];
2303 }
2304 }
2305 }
2306 OptimizerKind::Adam => {
2307 let cfg = adam_cfg.as_ref().expect("adam cfg initialized");
2308 let adam = model_adam.as_mut().expect("adam state exists");
2309 let head = self.lm_head.as_mut_slice();
2310 let hm = adam.lm_head.m.as_mut_slice();
2311 let hv = adam.lm_head.v.as_mut_slice();
2312 let norm_ptr = norm_in.as_ptr();
2313 let grad_h_ptr = scratch.grad_h.as_mut_slice().as_mut_ptr();
2314 let b1 = f32x8::splat(cfg.b1);
2315 let b2 = f32x8::splat(cfg.b2);
2316 let one_b1 = f32x8::splat(1.0 - cfg.b1);
2317 let one_b2 = f32x8::splat(1.0 - cfg.b2);
2318 let inv_bc1 = f32x8::splat(1.0 / cfg.bias_corr1);
2319 let inv_bc2 = f32x8::splat(1.0 / cfg.bias_corr2);
2320 let eps = f32x8::splat(cfg.eps);
2321 let lr8 = f32x8::splat(cfg.lr);
2322 for tok in 0..v {
2323 let g = scratch.grad_logits[tok];
2324 let row_off = tok * c;
2325 let mut j = 0usize;
2326 unsafe {
2327 let g8 = f32x8::splat(g);
2328 while j + 8 <= c {
2329 let idx = row_off + j;
2330 let wv = head.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
2331 let nv = norm_ptr.add(j).cast::<f32x8>().read_unaligned();
2332 let ghv = grad_h_ptr.add(j).cast::<f32x8>().read_unaligned();
2333 grad_h_ptr
2334 .add(j)
2335 .cast::<f32x8>()
2336 .write_unaligned(ghv + wv * g8);
2337
2338 let gg = g8 * nv;
2339 let hm_old = hm.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
2340 let hv_old = hv.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
2341 let m = hm_old * b1 + gg * one_b1;
2342 let vv = hv_old * b2 + (gg * gg) * one_b2;
2343 hm.as_mut_ptr().add(idx).cast::<f32x8>().write_unaligned(m);
2344 hv.as_mut_ptr().add(idx).cast::<f32x8>().write_unaligned(vv);
2345 let upd = ((m * inv_bc1) / ((vv * inv_bc2).sqrt() + eps)) * lr8;
2346 head.as_mut_ptr()
2347 .add(idx)
2348 .cast::<f32x8>()
2349 .write_unaligned(wv + upd);
2350 j += 8;
2351 }
2352 }
2353 while j < c {
2354 let idx = row_off + j;
2355 let w_old = head[idx];
2356 scratch.grad_h[j] += w_old * g;
2357 let gg = g * norm_in[j];
2358 let m = cfg.b1 * hm[idx] + (1.0 - cfg.b1) * gg;
2359 let vv = cfg.b2 * hv[idx] + (1.0 - cfg.b2) * gg * gg;
2360 hm[idx] = m;
2361 hv[idx] = vv;
2362 let m_hat = m / cfg.bias_corr1;
2363 let v_hat = vv / cfg.bias_corr2;
2364 head[idx] = w_old + cfg.lr * m_hat / (v_hat.sqrt() + cfg.eps);
2365 j += 1;
2366 }
2367 }
2368 if let (Some(b), Some(bm)) = (self.lm_head_b.as_mut(), adam.lm_head_b.as_mut())
2369 {
2370 let bm_m = bm.m.as_mut_slice();
2371 let bm_v = bm.v.as_mut_slice();
2372 for tok in 0..v.min(b.len()) {
2373 let g = scratch.grad_logits[tok];
2374 let m = cfg.b1 * bm_m[tok] + (1.0 - cfg.b1) * g;
2375 let vv = cfg.b2 * bm_v[tok] + (1.0 - cfg.b2) * g * g;
2376 bm_m[tok] = m;
2377 bm_v[tok] = vv;
2378 let m_hat = m / cfg.bias_corr1;
2379 let v_hat = vv / cfg.bias_corr2;
2380 b[tok] += cfg.lr * m_hat / (v_hat.sqrt() + cfg.eps);
2381 }
2382 }
2383 }
2384 }
2385 } else {
2386 let head = self.lm_head.as_slice();
2387 for tok in 0..v {
2388 let g = scratch.grad_logits[tok];
2389 let row_off = tok * c;
2390 for j in 0..c {
2391 scratch.grad_h[j] += head[row_off + j] * g;
2392 }
2393 }
2394 }
2395
2396 if !needs_backprop {
2397 return Ok(());
2398 }
2399
2400 rms_norm_backward(
2402 scratch.train_h_final.as_slice(),
2403 self.final_norm_w.as_slice(),
2404 scratch.grad_h.as_slice(),
2405 self.cfg.layer_norm_eps,
2406 scratch.grad_norm.as_mut_slice(),
2407 scratch.grad_out.as_mut_slice(),
2408 );
2409 if scope.layer_norm {
2410 match optimizer {
2411 OptimizerKind::Sgd => {
2412 for idx in 0..c {
2413 self.final_norm_w[idx] += lr * scratch.grad_out[idx];
2414 }
2415 if let Some(b) = self.final_norm_b.as_mut() {
2416 for idx in 0..c.min(b.len()) {
2417 b[idx] += lr * scratch.grad_h[idx];
2418 }
2419 }
2420 }
2421 OptimizerKind::Adam => {
2422 let cfg = adam_cfg.as_ref().expect("adam cfg initialized");
2423 let adam = model_adam.as_mut().expect("adam state exists");
2424 apply_adam_vec_update(
2425 self.final_norm_w.as_mut_slice(),
2426 scratch.grad_out.as_slice(),
2427 &mut adam.final_norm_w,
2428 cfg,
2429 );
2430 if let (Some(b), Some(bm)) =
2431 (self.final_norm_b.as_mut(), adam.final_norm_b.as_mut())
2432 {
2433 apply_adam_vec_update(b.as_mut_slice(), scratch.grad_h.as_slice(), bm, cfg);
2434 }
2435 }
2436 }
2437 }
2438 scratch
2439 .grad_h
2440 .as_mut_slice()
2441 .copy_from_slice(scratch.grad_norm.as_slice());
2442
2443 for layer_idx in (0..self.cfg.num_layers).rev() {
2445 let tr = &scratch.train_trace_layers[layer_idx];
2446 let st_new = &state.layers[layer_idx];
2447 let layer = &mut self.layers[layer_idx];
2448
2449 scratch
2450 .grad_out
2451 .as_mut_slice()
2452 .copy_from_slice(scratch.grad_h.as_slice());
2453
2454 unsafe {
2456 kernel::gemv_t(
2457 layer.out_proj_w.as_ptr(),
2458 scratch.grad_out.as_ptr(),
2459 scratch.grad_y.as_mut_ptr(),
2460 c,
2461 i,
2462 );
2463 }
2464
2465 if scope.mixer_proj {
2466 match optimizer {
2467 OptimizerKind::Sgd => {
2468 for row in 0..c {
2469 let g = scratch.grad_out[row];
2470 let off = row * i;
2471 for col in 0..i {
2472 layer.out_proj_w[off + col] += lr * g * tr.y[col];
2473 }
2474 }
2475 if let Some(b) = layer.out_proj_b.as_mut() {
2476 for row in 0..c.min(b.len()) {
2477 b[row] += lr * scratch.grad_out[row];
2478 }
2479 }
2480 }
2481 OptimizerKind::Adam => {
2482 let cfg = adam_cfg.as_ref().expect("adam cfg initialized");
2483 let adam_layer =
2484 &mut model_adam.as_mut().expect("adam state exists").layers[layer_idx];
2485 apply_adam_outer_update(
2486 layer.out_proj_w.as_mut_slice(),
2487 c,
2488 i,
2489 scratch.grad_out.as_slice(),
2490 tr.y.as_slice(),
2491 &mut adam_layer.out_proj_w,
2492 cfg,
2493 );
2494 if let (Some(b), Some(bm)) =
2495 (layer.out_proj_b.as_mut(), adam_layer.out_proj_b.as_mut())
2496 {
2497 apply_adam_vec_update(
2498 b.as_mut_slice(),
2499 scratch.grad_out.as_slice(),
2500 bm,
2501 cfg,
2502 );
2503 }
2504 }
2505 }
2506 }
2507
2508 scratch
2509 .grad_residual
2510 .as_mut_slice()
2511 .copy_from_slice(scratch.grad_out.as_slice());
2512 scratch.grad_xz.zero();
2513 scratch.grad_b.zero();
2514 scratch.grad_c.zero();
2515
2516 for ch in 0..i {
2518 let g_y = scratch.grad_y[ch];
2519 let gate = tr.gate[ch];
2520 let y_pre = tr.y_pre[ch];
2521 let g_y_pre = g_y * gate;
2522 let g_gate = g_y * y_pre;
2523 scratch.grad_xz[i + ch] =
2524 g_gate * silu_grad_from_sigmoid(tr.xz[i + ch], tr.gate_sigmoid[ch]);
2525
2526 let conv = tr.conv_post[ch];
2527 let dt = tr.dt[ch];
2528 let xdt = conv * dt;
2529 let mut g_xdt = 0.0f32;
2530 let mut g_dt = 0.0f32;
2531
2532 let mut g_conv = g_y_pre * layer.d[ch];
2533 if scope.mixer_ssm {
2534 scratch.grad_ssm_d[ch] = g_y_pre * conv;
2535 }
2536
2537 let row = ch * s;
2538 for j in 0..s {
2539 let idx = row + j;
2540 let c_j = tr.proj[r + s + j];
2541 let b_j = tr.proj[r + j];
2542 let s_prev = tr.ssm_prev[idx];
2543 let s_new = st_new.ssm[idx];
2544 let a_ij = layer.a[idx];
2545
2546 let g_ssm_new = g_y_pre * c_j;
2547 scratch.grad_c[j] += g_y_pre * s_new;
2548 g_xdt += g_ssm_new * b_j;
2549 scratch.grad_b[j] += g_ssm_new * xdt;
2550
2551 let d_a = tr.d_a[idx];
2552 let g_d_a = g_ssm_new * s_prev;
2553 g_dt += g_d_a * d_a * a_ij;
2554 if scope.mixer_ssm {
2555 scratch.grad_ssm_a[idx] = g_d_a * d_a * dt;
2556 }
2557 }
2558
2559 g_conv += g_xdt * dt;
2560 g_dt += g_xdt * conv;
2561 let dt_pre = tr.dt_raw[ch] + layer.dt_proj_b[ch];
2562 scratch.grad_dt_raw[ch] = g_dt * sigmoid(dt_pre);
2563 scratch.grad_conv[ch] = g_conv;
2564 }
2565
2566 if scope.mixer_ssm {
2567 match optimizer {
2568 OptimizerKind::Sgd => {
2569 if clip > 0.0 {
2570 for idx in 0..i {
2571 layer.d[idx] += lr * scratch.grad_ssm_d[idx].clamp(-clip, clip);
2572 }
2573 for idx in 0..(i * s) {
2574 let g_log =
2575 (scratch.grad_ssm_a[idx] * layer.a[idx]).clamp(-clip, clip);
2576 let new_log = layer.a_log[idx] + lr * g_log;
2577 layer.a_log[idx] = new_log;
2578 layer.a[idx] = -new_log.exp();
2579 }
2580 } else {
2581 for idx in 0..i {
2582 layer.d[idx] += lr * scratch.grad_ssm_d[idx];
2583 }
2584 for idx in 0..(i * s) {
2585 let g_log = scratch.grad_ssm_a[idx] * layer.a[idx];
2586 let new_log = layer.a_log[idx] + lr * g_log;
2587 layer.a_log[idx] = new_log;
2588 layer.a[idx] = -new_log.exp();
2589 }
2590 }
2591 }
2592 OptimizerKind::Adam => {
2593 let cfg = adam_cfg.as_ref().expect("adam cfg initialized");
2594 let adam_layer =
2595 &mut model_adam.as_mut().expect("adam state exists").layers[layer_idx];
2596 for idx in 0..(i * s) {
2597 scratch.grad_ssm_a_log[idx] = scratch.grad_ssm_a[idx] * layer.a[idx];
2598 }
2599 apply_adam_vec_update(
2600 layer.d.as_mut_slice(),
2601 scratch.grad_ssm_d.as_slice(),
2602 &mut adam_layer.d,
2603 cfg,
2604 );
2605 apply_adam_vec_update_and_sync_neg_exp(
2606 layer.a_log.as_mut_slice(),
2607 layer.a.as_mut_slice(),
2608 scratch.grad_ssm_a_log.as_slice(),
2609 &mut adam_layer.a,
2610 cfg,
2611 );
2612 }
2613 }
2614 }
2615
2616 unsafe {
2618 kernel::gemv_t(
2619 layer.dt_proj_w.as_ptr(),
2620 scratch.grad_dt_raw.as_ptr(),
2621 scratch.grad_u.as_mut_ptr(),
2622 i,
2623 r,
2624 );
2625 }
2626 if scope.mixer_proj {
2627 match optimizer {
2628 OptimizerKind::Sgd => {
2629 for ch in 0..i {
2630 let g = scratch.grad_dt_raw[ch];
2631 let off = ch * r;
2632 for kk in 0..r {
2633 layer.dt_proj_w[off + kk] += lr * g * tr.proj[kk];
2634 }
2635 layer.dt_proj_b[ch] += lr * g;
2636 }
2637 }
2638 OptimizerKind::Adam => {
2639 let cfg = adam_cfg.as_ref().expect("adam cfg initialized");
2640 let adam_layer =
2641 &mut model_adam.as_mut().expect("adam state exists").layers[layer_idx];
2642 apply_adam_outer_update(
2643 layer.dt_proj_w.as_mut_slice(),
2644 i,
2645 r,
2646 scratch.grad_dt_raw.as_slice(),
2647 &tr.proj.as_slice()[..r],
2648 &mut adam_layer.dt_proj_w,
2649 cfg,
2650 );
2651 apply_adam_vec_update(
2652 layer.dt_proj_b.as_mut_slice(),
2653 scratch.grad_dt_raw.as_slice(),
2654 &mut adam_layer.dt_proj_b,
2655 cfg,
2656 );
2657 }
2658 }
2659 }
2660
2661 for kk in 0..r {
2662 scratch.grad_proj[kk] = scratch.grad_u[kk];
2663 }
2664 for j in 0..s {
2665 scratch.grad_proj[r + j] = scratch.grad_b[j];
2666 scratch.grad_proj[r + s + j] = scratch.grad_c[j];
2667 }
2668
2669 unsafe {
2671 kernel::gemv_t(
2672 layer.x_proj_w.as_ptr(),
2673 scratch.grad_proj.as_ptr(),
2674 scratch.grad_conv_pre.as_mut_ptr(),
2675 r + 2 * s,
2676 i,
2677 );
2678 kernel::add_inplace(
2679 scratch.grad_conv.as_mut_ptr(),
2680 scratch.grad_conv_pre.as_ptr(),
2681 i,
2682 );
2683 }
2684 if scope.mixer_proj {
2685 match optimizer {
2686 OptimizerKind::Sgd => {
2687 for row in 0..(r + 2 * s) {
2688 let g = scratch.grad_proj[row];
2689 let off = row * i;
2690 for col in 0..i {
2691 layer.x_proj_w[off + col] += lr * g * tr.conv_post[col];
2692 }
2693 }
2694 if let Some(b) = layer.x_proj_b.as_mut() {
2695 for row in 0..(r + 2 * s).min(b.len()) {
2696 b[row] += lr * scratch.grad_proj[row];
2697 }
2698 }
2699 }
2700 OptimizerKind::Adam => {
2701 let cfg = adam_cfg.as_ref().expect("adam cfg initialized");
2702 let adam_layer =
2703 &mut model_adam.as_mut().expect("adam state exists").layers[layer_idx];
2704 apply_adam_outer_update(
2705 layer.x_proj_w.as_mut_slice(),
2706 r + 2 * s,
2707 i,
2708 scratch.grad_proj.as_slice(),
2709 tr.conv_post.as_slice(),
2710 &mut adam_layer.x_proj_w,
2711 cfg,
2712 );
2713 if let (Some(b), Some(bm)) =
2714 (layer.x_proj_b.as_mut(), adam_layer.x_proj_b.as_mut())
2715 {
2716 apply_adam_vec_update(
2717 b.as_mut_slice(),
2718 scratch.grad_proj.as_slice(),
2719 bm,
2720 cfg,
2721 );
2722 }
2723 }
2724 }
2725 }
2726
2727 for ch in 0..i {
2729 scratch.grad_conv_pre[ch] = scratch.grad_conv[ch]
2730 * silu_grad_from_sigmoid(tr.conv_pre[ch], tr.conv_sigmoid[ch]);
2731 }
2732
2733 for ch in 0..i {
2734 let g = scratch.grad_conv_pre[ch];
2735 let base = ch * self.cfg.conv_kernel;
2736 let w0 = layer.conv_w[base];
2737 scratch.grad_xz[ch] += g * w0;
2738 if scope.mixer_conv && self.cfg.conv_kernel == 4 {
2739 let vals = match tr.conv_pos_prev {
2740 0 => [
2741 tr.xz[ch],
2742 tr.conv_prev[base + 3],
2743 tr.conv_prev[base + 2],
2744 tr.conv_prev[base + 1],
2745 ],
2746 1 => [
2747 tr.xz[ch],
2748 tr.conv_prev[base],
2749 tr.conv_prev[base + 3],
2750 tr.conv_prev[base + 2],
2751 ],
2752 2 => [
2753 tr.xz[ch],
2754 tr.conv_prev[base + 1],
2755 tr.conv_prev[base],
2756 tr.conv_prev[base + 3],
2757 ],
2758 _ => [
2759 tr.xz[ch],
2760 tr.conv_prev[base + 2],
2761 tr.conv_prev[base + 1],
2762 tr.conv_prev[base],
2763 ],
2764 };
2765 scratch.grad_conv_w[base] = g * vals[0];
2766 scratch.grad_conv_w[base + 1] = g * vals[1];
2767 scratch.grad_conv_w[base + 2] = g * vals[2];
2768 scratch.grad_conv_w[base + 3] = g * vals[3];
2769 } else {
2770 let mut ring = tr.conv_pos_prev;
2771 for tap in 0..self.cfg.conv_kernel {
2772 let val = if ring == tr.conv_pos_prev {
2773 tr.xz[ch]
2774 } else {
2775 tr.conv_prev[base + ring]
2776 };
2777 if scope.mixer_conv {
2778 scratch.grad_conv_w[base + tap] = g * val;
2779 }
2780 ring = if ring == 0 {
2781 self.cfg.conv_kernel - 1
2782 } else {
2783 ring - 1
2784 };
2785 }
2786 }
2787 if scope.mixer_conv && layer.conv_b.is_some() {
2788 scratch.grad_conv_b[ch] = g;
2789 }
2790 }
2791
2792 if scope.mixer_conv {
2793 match optimizer {
2794 OptimizerKind::Sgd => {
2795 if clip > 0.0 {
2796 for idx in 0..layer.conv_w.len() {
2797 layer.conv_w[idx] +=
2798 lr * scratch.grad_conv_w[idx].clamp(-clip, clip);
2799 }
2800 } else {
2801 for idx in 0..layer.conv_w.len() {
2802 layer.conv_w[idx] += lr * scratch.grad_conv_w[idx];
2803 }
2804 }
2805 if let Some(bias) = layer.conv_b.as_mut() {
2806 if clip > 0.0 {
2807 for idx in 0..bias.len().min(i) {
2808 bias[idx] += lr * scratch.grad_conv_b[idx].clamp(-clip, clip);
2809 }
2810 } else {
2811 for idx in 0..bias.len().min(i) {
2812 bias[idx] += lr * scratch.grad_conv_b[idx];
2813 }
2814 }
2815 }
2816 }
2817 OptimizerKind::Adam => {
2818 let cfg = adam_cfg.as_ref().expect("adam cfg initialized");
2819 let adam_layer =
2820 &mut model_adam.as_mut().expect("adam state exists").layers[layer_idx];
2821 apply_adam_vec_update(
2822 layer.conv_w.as_mut_slice(),
2823 scratch.grad_conv_w.as_slice(),
2824 &mut adam_layer.conv_w,
2825 cfg,
2826 );
2827 if let (Some(bias), Some(bm)) =
2828 (layer.conv_b.as_mut(), adam_layer.conv_b.as_mut())
2829 {
2830 apply_adam_vec_update(
2831 bias.as_mut_slice(),
2832 scratch.grad_conv_b.as_slice(),
2833 bm,
2834 cfg,
2835 );
2836 }
2837 }
2838 }
2839 }
2840
2841 unsafe {
2843 kernel::gemv_t(
2844 layer.in_proj_w.as_ptr(),
2845 scratch.grad_xz.as_ptr(),
2846 scratch.grad_norm.as_mut_ptr(),
2847 2 * i,
2848 c,
2849 );
2850 }
2851 if scope.mixer_proj {
2852 match optimizer {
2853 OptimizerKind::Sgd => {
2854 for row in 0..(2 * i) {
2855 let g = scratch.grad_xz[row];
2856 let off = row * c;
2857 for col in 0..c {
2858 layer.in_proj_w[off + col] += lr * g * tr.norm[col];
2859 }
2860 }
2861 if let Some(b) = layer.in_proj_b.as_mut() {
2862 for row in 0..(2 * i).min(b.len()) {
2863 b[row] += lr * scratch.grad_xz[row];
2864 }
2865 }
2866 }
2867 OptimizerKind::Adam => {
2868 let cfg = adam_cfg.as_ref().expect("adam cfg initialized");
2869 let adam_layer =
2870 &mut model_adam.as_mut().expect("adam state exists").layers[layer_idx];
2871 apply_adam_outer_update(
2872 layer.in_proj_w.as_mut_slice(),
2873 2 * i,
2874 c,
2875 scratch.grad_xz.as_slice(),
2876 tr.norm.as_slice(),
2877 &mut adam_layer.in_proj_w,
2878 cfg,
2879 );
2880 if let (Some(b), Some(bm)) =
2881 (layer.in_proj_b.as_mut(), adam_layer.in_proj_b.as_mut())
2882 {
2883 apply_adam_vec_update(
2884 b.as_mut_slice(),
2885 scratch.grad_xz.as_slice(),
2886 bm,
2887 cfg,
2888 );
2889 }
2890 }
2891 }
2892 }
2893
2894 rms_norm_backward(
2896 tr.h_in.as_slice(),
2897 layer.norm_w.as_slice(),
2898 scratch.grad_norm.as_slice(),
2899 self.cfg.layer_norm_eps,
2900 scratch.grad_h.as_mut_slice(),
2901 scratch.grad_out.as_mut_slice(),
2902 );
2903 if scope.layer_norm {
2904 match optimizer {
2905 OptimizerKind::Sgd => {
2906 for idx in 0..c {
2907 layer.norm_w[idx] += lr * scratch.grad_out[idx];
2908 }
2909 if let Some(b) = layer.norm_b.as_mut() {
2910 for idx in 0..c.min(b.len()) {
2911 b[idx] += lr * scratch.grad_norm[idx];
2912 }
2913 }
2914 }
2915 OptimizerKind::Adam => {
2916 let cfg = adam_cfg.as_ref().expect("adam cfg initialized");
2917 let adam_layer =
2918 &mut model_adam.as_mut().expect("adam state exists").layers[layer_idx];
2919 apply_adam_vec_update(
2920 layer.norm_w.as_mut_slice(),
2921 scratch.grad_out.as_slice(),
2922 &mut adam_layer.norm_w,
2923 cfg,
2924 );
2925 if let (Some(b), Some(bm)) =
2926 (layer.norm_b.as_mut(), adam_layer.norm_b.as_mut())
2927 {
2928 apply_adam_vec_update(
2929 b.as_mut_slice(),
2930 scratch.grad_norm.as_slice(),
2931 bm,
2932 cfg,
2933 );
2934 }
2935 }
2936 }
2937 }
2938
2939 for idx in 0..c {
2940 scratch.grad_h[idx] += scratch.grad_residual[idx];
2941 }
2942 }
2943
2944 if scope.embed {
2945 let tok = scratch
2946 .train_token
2947 .min(self.cfg.vocab_size.saturating_sub(1));
2948 let row_off = tok * c;
2949 match optimizer {
2950 OptimizerKind::Sgd => {
2951 for j in 0..c {
2952 let g = if clip > 0.0 {
2953 scratch.grad_h[j].clamp(-clip, clip)
2954 } else {
2955 scratch.grad_h[j]
2956 };
2957 self.embeddings[row_off + j] += lr * g;
2958 }
2959 }
2960 OptimizerKind::Adam => {
2961 let cfg = adam_cfg.as_ref().expect("adam cfg initialized");
2962 let adam = model_adam.as_mut().expect("adam state exists");
2963 let pm = adam.embeddings.m.as_mut_slice();
2964 let pv = adam.embeddings.v.as_mut_slice();
2965 let emb = self.embeddings.as_mut_slice();
2966 for j in 0..c {
2967 let idx = row_off + j;
2968 let g = scratch.grad_h[j];
2969 pm[idx] = cfg.b1 * pm[idx] + (1.0 - cfg.b1) * g;
2970 pv[idx] = cfg.b2 * pv[idx] + (1.0 - cfg.b2) * g * g;
2971 let m_hat = pm[idx] / cfg.bias_corr1;
2972 let v_hat = pv[idx] / cfg.bias_corr2;
2973 emb[idx] += cfg.lr * m_hat / (v_hat.sqrt() + cfg.eps);
2974 }
2975 }
2976 }
2977 }
2978
2979 Ok(())
2980 }
2981
    /// Loads a model from a checkpoint using this crate's native tensor
    /// naming (`model.embeddings.weight`, `model.layers.{n}.*`, `model.norm.*`).
    ///
    /// All structural hyperparameters are inferred from tensor shapes rather
    /// than a config file: vocab/hidden from the embedding table, layer count
    /// from the highest `model.layers.N` index present, and the
    /// inner/conv/state/dt sizes from layer 0's mixer tensors. Every layer is
    /// then re-checked against the inferred `Config` via `validate_layer_shapes`.
    ///
    /// Errors when a required tensor is missing, a shape disagrees with the
    /// inferred configuration, or no layers are found.
    fn load_native(weights: &Weights) -> Result<Self> {
        let emb = weights.require("model.embeddings.weight")?;
        if emb.shape().len() != 2 {
            bail!("model.embeddings.weight must be rank-2");
        }
        // Embedding table is [vocab, hidden]; both dims seed the config.
        let vocab_size = emb.shape()[0];
        let hidden_size = emb.shape()[1];

        let num_layers = count_layers(weights, "model.layers.", "mixer.in_proj.weight")?;
        if num_layers == 0 {
            bail!("no Mamba layers found in native checkpoint");
        }

        // Layer 0's mixer tensors determine the remaining sizes; the per-layer
        // validation below enforces that all other layers agree with them.
        let first_in = weights.require("model.layers.0.mixer.in_proj.weight")?;
        let first_conv = weights.require("model.layers.0.mixer.conv1d.weight")?;
        let first_a = weights.require("model.layers.0.mixer.A_log")?;
        let first_dt = weights.require("model.layers.0.mixer.dt_proj.weight")?;

        let inner_size =
            infer_in_proj_inner(first_in, hidden_size, "model.layers.0.mixer.in_proj.weight")?;
        let conv_kernel =
            infer_conv_kernel(first_conv, inner_size, "model.layers.0.mixer.conv1d.weight")?;
        let state_size = infer_state_size(first_a, inner_size, "model.layers.0.mixer.A_log")?;
        let dt_rank = infer_dt_rank(first_dt, inner_size, "model.layers.0.mixer.dt_proj.weight")?;

        let cfg = Config {
            vocab_size,
            hidden_size,
            num_layers,
            inner_size,
            state_size,
            conv_kernel,
            dt_rank,
            // Not stored in the checkpoint; fixed at the same default used by
            // Config::default().
            layer_norm_eps: 1e-5,
        };
        cfg.validate()?;

        let embeddings = tensor_from(emb)?;
        let final_norm_w = tensor_from(weights.require("model.norm.weight")?)?;
        let final_norm_b = optional_tensor_from(weights, "model.norm.bias")?;
        // Weight tying: when no dedicated lm_head.weight tensor exists, reuse
        // a copy of the embedding matrix as the output head.
        let lm_head = if let Some(t) = weights.get("lm_head.weight") {
            tensor_from(t)?
        } else {
            embeddings.clone()
        };
        let lm_head_b = optional_tensor_from(weights, "lm_head.bias")?;

        let mut layers = Vec::with_capacity(num_layers);
        for idx in 0..num_layers {
            let root = format!("model.layers.{idx}");
            let mixer = format!("{root}.mixer");

            let norm_w = tensor_from(weights.require(&format!("{root}.norm.weight"))?)?;
            let norm_b = optional_tensor_from(weights, &format!("{root}.norm.bias"))?;

            let in_proj_w = tensor_from(weights.require(&format!("{mixer}.in_proj.weight"))?)?;
            let in_proj_b = optional_tensor_from(weights, &format!("{mixer}.in_proj.bias"))?;

            // Conv weights may arrive as [inner, k] or [inner, 1, k];
            // tensor_from_conv flattens either layout to [inner * k].
            let conv_w = tensor_from_conv(
                weights.require(&format!("{mixer}.conv1d.weight"))?,
                inner_size,
            )?;
            let conv_b = optional_tensor_from(weights, &format!("{mixer}.conv1d.bias"))?;

            let x_proj_w = tensor_from(weights.require(&format!("{mixer}.x_proj.weight"))?)?;
            let x_proj_b = optional_tensor_from(weights, &format!("{mixer}.x_proj.bias"))?;

            let dt_proj_w = tensor_from(weights.require(&format!("{mixer}.dt_proj.weight"))?)?;
            let dt_proj_b = tensor_from(weights.require(&format!("{mixer}.dt_proj.bias"))?)?;

            // The checkpoint stores A_log; precompute A = -exp(A_log) once at
            // load time so inference never has to.
            let a_log = tensor_from(weights.require(&format!("{mixer}.A_log"))?)?;
            let a = a_from_a_log_tensor(&a_log);
            let d = tensor_from(weights.require(&format!("{mixer}.D"))?)?;

            let out_proj_w = tensor_from(weights.require(&format!("{mixer}.out_proj.weight"))?)?;
            let out_proj_b = optional_tensor_from(weights, &format!("{mixer}.out_proj.bias"))?;

            validate_layer_shapes(
                &cfg,
                idx,
                &norm_w,
                norm_b.as_ref(),
                &in_proj_w,
                in_proj_b.as_ref(),
                &conv_w,
                conv_b.as_ref(),
                &x_proj_w,
                x_proj_b.as_ref(),
                &dt_proj_w,
                &dt_proj_b,
                &a,
                &d,
                &out_proj_w,
                out_proj_b.as_ref(),
            )?;

            layers.push(LayerWeights {
                norm_w,
                norm_b,
                in_proj_w,
                in_proj_b,
                conv_w,
                conv_b,
                x_proj_w,
                x_proj_b,
                dt_proj_w,
                dt_proj_b,
                a_log,
                a,
                d,
                out_proj_w,
                out_proj_b,
            });
        }

        Ok(Self {
            cfg,
            embeddings,
            final_norm_w,
            final_norm_b,
            lm_head,
            lm_head_b,
            layers,
        })
    }
3107
    /// Loads a model from a checkpoint using the official Mamba tensor naming
    /// (`backbone.embedding.weight`, `backbone.layers.{n}.*`, `norm_f.*`).
    ///
    /// Mirrors `load_native` exactly except for the tensor-name prefixes:
    /// hyperparameters are inferred from tensor shapes, then every layer is
    /// cross-checked against the inferred `Config` via `validate_layer_shapes`.
    ///
    /// Errors when a required tensor is missing, a shape disagrees with the
    /// inferred configuration, or no layers are found.
    fn load_official(weights: &Weights) -> Result<Self> {
        let emb = weights.require("backbone.embedding.weight")?;
        if emb.shape().len() != 2 {
            bail!("backbone.embedding.weight must be rank-2");
        }
        // Embedding table is [vocab, hidden]; both dims seed the config.
        let vocab_size = emb.shape()[0];
        let hidden_size = emb.shape()[1];

        let num_layers = count_layers(weights, "backbone.layers.", "mixer.in_proj.weight")?;
        if num_layers == 0 {
            bail!("no Mamba layers found in official checkpoint");
        }

        // Layer 0's mixer tensors determine the remaining sizes; the per-layer
        // validation below enforces that all other layers agree with them.
        let first_in = weights.require("backbone.layers.0.mixer.in_proj.weight")?;
        let first_conv = weights.require("backbone.layers.0.mixer.conv1d.weight")?;
        let first_a = weights.require("backbone.layers.0.mixer.A_log")?;
        let first_dt = weights.require("backbone.layers.0.mixer.dt_proj.weight")?;

        let inner_size = infer_in_proj_inner(
            first_in,
            hidden_size,
            "backbone.layers.0.mixer.in_proj.weight",
        )?;
        let conv_kernel = infer_conv_kernel(
            first_conv,
            inner_size,
            "backbone.layers.0.mixer.conv1d.weight",
        )?;
        let state_size = infer_state_size(first_a, inner_size, "backbone.layers.0.mixer.A_log")?;
        let dt_rank = infer_dt_rank(
            first_dt,
            inner_size,
            "backbone.layers.0.mixer.dt_proj.weight",
        )?;

        let cfg = Config {
            vocab_size,
            hidden_size,
            num_layers,
            inner_size,
            state_size,
            conv_kernel,
            dt_rank,
            // Not stored in the checkpoint; fixed at the same default used by
            // Config::default().
            layer_norm_eps: 1e-5,
        };
        cfg.validate()?;

        let embeddings = tensor_from(emb)?;
        let final_norm_w = tensor_from(weights.require("norm_f.weight")?)?;
        let final_norm_b = optional_tensor_from(weights, "norm_f.bias")?;
        // Weight tying: when no dedicated lm_head.weight tensor exists, reuse
        // a copy of the embedding matrix as the output head.
        let lm_head = if let Some(t) = weights.get("lm_head.weight") {
            tensor_from(t)?
        } else {
            embeddings.clone()
        };
        let lm_head_b = optional_tensor_from(weights, "lm_head.bias")?;

        let mut layers = Vec::with_capacity(num_layers);
        for idx in 0..num_layers {
            let root = format!("backbone.layers.{idx}");
            let mixer = format!("{root}.mixer");

            let norm_w = tensor_from(weights.require(&format!("{root}.norm.weight"))?)?;
            let norm_b = optional_tensor_from(weights, &format!("{root}.norm.bias"))?;

            let in_proj_w = tensor_from(weights.require(&format!("{mixer}.in_proj.weight"))?)?;
            let in_proj_b = optional_tensor_from(weights, &format!("{mixer}.in_proj.bias"))?;

            // Conv weights may arrive as [inner, k] or [inner, 1, k];
            // tensor_from_conv flattens either layout to [inner * k].
            let conv_w = tensor_from_conv(
                weights.require(&format!("{mixer}.conv1d.weight"))?,
                inner_size,
            )?;
            let conv_b = optional_tensor_from(weights, &format!("{mixer}.conv1d.bias"))?;

            let x_proj_w = tensor_from(weights.require(&format!("{mixer}.x_proj.weight"))?)?;
            let x_proj_b = optional_tensor_from(weights, &format!("{mixer}.x_proj.bias"))?;

            let dt_proj_w = tensor_from(weights.require(&format!("{mixer}.dt_proj.weight"))?)?;
            let dt_proj_b = tensor_from(weights.require(&format!("{mixer}.dt_proj.bias"))?)?;

            // The checkpoint stores A_log; precompute A = -exp(A_log) once at
            // load time so inference never has to.
            let a_log = tensor_from(weights.require(&format!("{mixer}.A_log"))?)?;
            let a = a_from_a_log_tensor(&a_log);
            let d = tensor_from(weights.require(&format!("{mixer}.D"))?)?;

            let out_proj_w = tensor_from(weights.require(&format!("{mixer}.out_proj.weight"))?)?;
            let out_proj_b = optional_tensor_from(weights, &format!("{mixer}.out_proj.bias"))?;

            validate_layer_shapes(
                &cfg,
                idx,
                &norm_w,
                norm_b.as_ref(),
                &in_proj_w,
                in_proj_b.as_ref(),
                &conv_w,
                conv_b.as_ref(),
                &x_proj_w,
                x_proj_b.as_ref(),
                &dt_proj_w,
                &dt_proj_b,
                &a,
                &d,
                &out_proj_w,
                out_proj_b.as_ref(),
            )?;

            layers.push(LayerWeights {
                norm_w,
                norm_b,
                in_proj_w,
                in_proj_b,
                conv_w,
                conv_b,
                x_proj_w,
                x_proj_b,
                dt_proj_w,
                dt_proj_b,
                a_log,
                a,
                d,
                out_proj_w,
                out_proj_b,
            });
        }

        Ok(Self {
            cfg,
            embeddings,
            final_norm_w,
            final_norm_b,
            lm_head,
            lm_head_b,
            layers,
        })
    }
3243}
3244
3245fn tensor_from(t: &WeightTensor) -> Result<Tensor1D> {
3246 Ok(Tensor1D::from_vec(t.data().to_vec()))
3247}
3248
3249fn a_from_a_log_tensor(a_log: &Tensor1D) -> Tensor1D {
3250 let mut out = a_log.as_slice().to_vec();
3251 for v in &mut out {
3252 *v = -v.exp();
3253 }
3254 Tensor1D::from_vec(out)
3255}
3256
3257fn optional_tensor_from(weights: &Weights, name: &str) -> Result<Option<Tensor1D>> {
3258 match weights.get(name) {
3259 Some(t) => Ok(Some(tensor_from(t)?)),
3260 None => Ok(None),
3261 }
3262}
3263
3264fn tensor_from_conv(t: &WeightTensor, inner_size: usize) -> Result<Tensor1D> {
3265 match t.shape() {
3266 [i, _k] if *i == inner_size => Ok(Tensor1D::from_vec(t.data().to_vec())),
3267 [i, one, k] if *i == inner_size && *one == 1 => {
3268 let mut out = Vec::with_capacity(inner_size * k);
3269 let src = t.data();
3270 for ch in 0..inner_size {
3271 let off = ch * k;
3272 out.extend_from_slice(&src[off..off + k]);
3273 }
3274 Ok(Tensor1D::from_vec(out))
3275 }
3276 other => bail!("unexpected conv1d weight shape {:?}", other),
3277 }
3278}
3279
3280fn count_layers(weights: &Weights, prefix: &str, suffix: &str) -> Result<usize> {
3281 let mut max_layer = None::<usize>;
3282 for name in weights.tensor_names() {
3283 let Some(rest) = name.strip_prefix(prefix) else {
3284 continue;
3285 };
3286 let Some((idx_s, tail)) = rest.split_once('.') else {
3287 continue;
3288 };
3289 if tail != suffix {
3290 continue;
3291 }
3292 let idx = idx_s
3293 .parse::<usize>()
3294 .with_context(|| format!("invalid layer index in tensor name '{name}'"))?;
3295 max_layer = Some(max_layer.map_or(idx, |m| m.max(idx)));
3296 }
3297 Ok(max_layer.map_or(0, |m| m + 1))
3298}
3299
3300fn infer_in_proj_inner(t: &WeightTensor, hidden: usize, name: &str) -> Result<usize> {
3301 let shape = t.shape();
3302 if shape.len() != 2 {
3303 bail!("{name} must be rank-2, got {:?}", shape);
3304 }
3305 if shape[1] != hidden {
3306 bail!("{name} expected cols={}, got {}", hidden, shape[1]);
3307 }
3308 if !shape[0].is_multiple_of(2) {
3309 bail!("{name} first dim {} must be 2*d_inner", shape[0]);
3310 }
3311 Ok(shape[0] / 2)
3312}
3313
3314fn infer_conv_kernel(t: &WeightTensor, inner: usize, name: &str) -> Result<usize> {
3315 let shape = t.shape();
3316 match shape {
3317 [i, k] if *i == inner => Ok(*k),
3318 [i, one, k] if *i == inner && *one == 1 => Ok(*k),
3319 _ => bail!("{name} shape {:?} incompatible with d_inner={inner}", shape),
3320 }
3321}
3322
3323fn infer_state_size(t: &WeightTensor, inner: usize, name: &str) -> Result<usize> {
3324 let shape = t.shape();
3325 if shape.len() != 2 {
3326 bail!("{name} must be rank-2, got {:?}", shape);
3327 }
3328 if shape[0] != inner {
3329 bail!("{name} expected rows={}, got {}", inner, shape[0]);
3330 }
3331 Ok(shape[1])
3332}
3333
3334fn infer_dt_rank(t: &WeightTensor, inner: usize, name: &str) -> Result<usize> {
3335 let shape = t.shape();
3336 if shape.len() != 2 {
3337 bail!("{name} must be rank-2, got {:?}", shape);
3338 }
3339 if shape[0] != inner {
3340 bail!("{name} expected rows={}, got {}", inner, shape[0]);
3341 }
3342 Ok(shape[1])
3343}
3344
3345#[allow(clippy::too_many_arguments)]
3346fn validate_layer_shapes(
3347 cfg: &Config,
3348 idx: usize,
3349 norm_w: &Tensor1D,
3350 norm_b: Option<&Tensor1D>,
3351 in_proj_w: &Tensor1D,
3352 in_proj_b: Option<&Tensor1D>,
3353 conv_w: &Tensor1D,
3354 conv_b: Option<&Tensor1D>,
3355 x_proj_w: &Tensor1D,
3356 x_proj_b: Option<&Tensor1D>,
3357 dt_proj_w: &Tensor1D,
3358 dt_proj_b: &Tensor1D,
3359 a: &Tensor1D,
3360 d: &Tensor1D,
3361 out_proj_w: &Tensor1D,
3362 out_proj_b: Option<&Tensor1D>,
3363) -> Result<()> {
3364 let c = cfg.hidden_size;
3365 let i = cfg.inner_size;
3366 let s = cfg.state_size;
3367 let k = cfg.conv_kernel;
3368 let r = cfg.dt_rank;
3369
3370 let check = |cond: bool, msg: String| -> Result<()> {
3371 if cond {
3372 Ok(())
3373 } else {
3374 bail!("layer {idx}: {msg}")
3375 }
3376 };
3377
3378 check(
3379 norm_w.len() == c,
3380 format!("norm.weight len {} != hidden {c}", norm_w.len()),
3381 )?;
3382 if let Some(b) = norm_b {
3383 check(
3384 b.len() == c,
3385 format!("norm.bias len {} != hidden {c}", b.len()),
3386 )?;
3387 }
3388
3389 check(
3390 in_proj_w.len() == (2 * i) * c,
3391 format!("in_proj.weight len {} != {}", in_proj_w.len(), (2 * i) * c),
3392 )?;
3393 if let Some(b) = in_proj_b {
3394 check(
3395 b.len() == 2 * i,
3396 format!("in_proj.bias len {} != {}", b.len(), 2 * i),
3397 )?;
3398 }
3399
3400 check(
3401 conv_w.len() == i * k,
3402 format!("conv1d.weight len {} != {}", conv_w.len(), i * k),
3403 )?;
3404 if let Some(b) = conv_b {
3405 check(b.len() == i, format!("conv1d.bias len {} != {i}", b.len()))?;
3406 }
3407
3408 check(
3409 x_proj_w.len() == (r + 2 * s) * i,
3410 format!(
3411 "x_proj.weight len {} != {}",
3412 x_proj_w.len(),
3413 (r + 2 * s) * i
3414 ),
3415 )?;
3416 if let Some(b) = x_proj_b {
3417 check(
3418 b.len() == r + 2 * s,
3419 format!("x_proj.bias len {} != {}", b.len(), r + 2 * s),
3420 )?;
3421 }
3422
3423 check(
3424 dt_proj_w.len() == i * r,
3425 format!("dt_proj.weight len {} != {}", dt_proj_w.len(), i * r),
3426 )?;
3427 check(
3428 dt_proj_b.len() == i,
3429 format!("dt_proj.bias len {} != {i}", dt_proj_b.len()),
3430 )?;
3431
3432 check(a.len() == i * s, format!("A len {} != {}", a.len(), i * s))?;
3433 check(d.len() == i, format!("D len {} != {i}", d.len()))?;
3434
3435 check(
3436 out_proj_w.len() == c * i,
3437 format!("out_proj.weight len {} != {}", out_proj_w.len(), c * i),
3438 )?;
3439 if let Some(b) = out_proj_b {
3440 check(
3441 b.len() == c,
3442 format!("out_proj.bias len {} != {c}", b.len()),
3443 )?;
3444 }
3445
3446 Ok(())
3447}
3448
/// RMS-normalizes `input` into `out`: `out[i] = input[i] / rms(input) * weight[i] (+ bias[i])`.
///
/// `rms` is the root mean square over the whole slice, with `eps` added under
/// the square root for numerical stability. All slices (including the
/// optional bias) must share one length; checked in debug builds only.
fn rms_norm(input: &[f32], weight: &[f32], bias: Option<&[f32]>, eps: f32, out: &mut [f32]) {
    debug_assert_eq!(input.len(), weight.len());
    debug_assert_eq!(input.len(), out.len());
    if let Some(b) = bias {
        debug_assert_eq!(b.len(), input.len());
    }

    // max(1) keeps the empty-slice case from dividing by zero.
    let denom = input.len().max(1) as f32;
    let mean_sq = input.iter().map(|&x| x * x).sum::<f32>() / denom;
    let inv = (mean_sq + eps).sqrt().recip();

    match bias {
        Some(b) => {
            for (slot, ((&x, &w), &bv)) in out.iter_mut().zip(input.iter().zip(weight).zip(b)) {
                *slot = x * inv * w + bv;
            }
        }
        None => {
            for (slot, (&x, &w)) in out.iter_mut().zip(input.iter().zip(weight)) {
                *slot = x * inv * w;
            }
        }
    }
}
3473
/// Backward pass for the weight-scaled RMS norm in `rms_norm`.
///
/// Given upstream `grad_out`, writes d(loss)/d(input) into `grad_input` and
/// d(loss)/d(weight) into `grad_weight`. `eps` must match the forward call.
/// All slices must share one length; checked in debug builds only.
fn rms_norm_backward(
    input: &[f32],
    weight: &[f32],
    grad_out: &[f32],
    eps: f32,
    grad_input: &mut [f32],
    grad_weight: &mut [f32],
) {
    debug_assert_eq!(input.len(), weight.len());
    debug_assert_eq!(input.len(), grad_out.len());
    debug_assert_eq!(input.len(), grad_input.len());
    debug_assert_eq!(input.len(), grad_weight.len());

    let n = input.len().max(1) as f32;
    let mean_sq = input.iter().map(|&x| x * x).sum::<f32>() / n;
    let inv = (mean_sq + eps).sqrt().recip();

    // First pass: per-element weight gradient, plus the reduction
    // sum_j(grad_out_j * w_j * x_j) needed for the normalizer's gradient.
    let mut dot = 0.0f32;
    for idx in 0..input.len() {
        let upstream = grad_out[idx];
        grad_weight[idx] = upstream * input[idx] * inv;
        dot += upstream * weight[idx] * input[idx];
    }

    // Second pass: chain rule through both the scaled value and the shared
    // 1/rms factor.
    let coeff = -dot * inv * inv * inv / n;
    for (slot, ((&x, &w), &go)) in grad_input
        .iter_mut()
        .zip(input.iter().zip(weight).zip(grad_out))
    {
        *slot = go * w * inv + x * coeff;
    }
}
3506
/// Element-wise accumulate: `dst[i] += src[i]` over the shorter of the two lengths.
#[inline(always)]
fn add_vec_grad(dst: &mut [f32], src: &[f32]) {
    for (acc, &g) in dst.iter_mut().zip(src) {
        *acc += g;
    }
}
3514
/// SGD step over parallel slices: `param[i] += lr * grad[i]`, with each
/// gradient clamped to `[-clip, clip]` first when `clip > 0` (clipping is
/// disabled otherwise). Walks the shorter of the two lengths.
///
/// NOTE: the `+=` sign convention matches the other parameter updates in
/// this file.
#[inline(always)]
fn sgd_vec_update(param: &mut [f32], grad: &[f32], lr: f32, clip: f32) {
    if clip > 0.0 {
        for (p, &g) in param.iter_mut().zip(grad) {
            *p += lr * g.clamp(-clip, clip);
        }
    } else {
        for (p, &g) in param.iter_mut().zip(grad) {
            *p += lr * g;
        }
    }
}
3528
/// Rank-1 accumulate: `dst[row, col] += left[row] * right[col]` for a
/// row-major `dst` of logical shape `[rows, cols]`.
///
/// All bounds are defensively truncated: `rows`/`cols` shrink to the operand
/// lengths, and a `dst` shorter than `rows * cols` receives a clipped update.
#[inline(always)]
fn add_outer_grad(dst: &mut [f32], rows: usize, cols: usize, left: &[f32], right: &[f32]) {
    let rows = rows.min(left.len());
    let cols = cols.min(right.len());
    let total = dst.len();
    for (row, &g) in left.iter().take(rows).enumerate() {
        let off = row * cols;
        if off >= total {
            break;
        }
        let span = (total - off).min(cols);
        for (slot, &r) in dst[off..off + span].iter_mut().zip(&right[..span]) {
            *slot += g * r;
        }
    }
}
3547
3548#[inline(always)]
3549fn apply_adam_vec_update_raw(
3550 param: &mut [f32],
3551 grad: &[f32],
3552 moment_m: &mut [f32],
3553 moment_v: &mut [f32],
3554 step: &AdamStep,
3555) {
3556 let n = param
3557 .len()
3558 .min(grad.len())
3559 .min(moment_m.len())
3560 .min(moment_v.len());
3561 if n == 0 {
3562 return;
3563 }
3564 let b1 = step.b1;
3565 let b2 = step.b2;
3566 let one_m_b1 = 1.0 - b1;
3567 let one_m_b2 = 1.0 - b2;
3568 let lr = step.lr;
3569 let eps = step.eps;
3570 let inv_bc1 = 1.0 / step.bias_corr1;
3571 let inv_bc2 = 1.0 / step.bias_corr2;
3572 if step.clip > 0.0 {
3573 let clip = step.clip;
3574 for idx in 0..n {
3575 let g = grad[idx].clamp(-clip, clip);
3576 let m = b1 * moment_m[idx] + one_m_b1 * g;
3577 let v = b2 * moment_v[idx] + one_m_b2 * g * g;
3578 moment_m[idx] = m;
3579 moment_v[idx] = v;
3580 let m_hat = m * inv_bc1;
3581 let v_hat = v * inv_bc2;
3582 param[idx] += lr * m_hat / (v_hat.sqrt() + eps);
3583 }
3584 } else {
3585 for idx in 0..n {
3586 let g = grad[idx];
3587 let m = b1 * moment_m[idx] + one_m_b1 * g;
3588 let v = b2 * moment_v[idx] + one_m_b2 * g * g;
3589 moment_m[idx] = m;
3590 moment_v[idx] = v;
3591 let m_hat = m * inv_bc1;
3592 let v_hat = v * inv_bc2;
3593 param[idx] += lr * m_hat / (v_hat.sqrt() + eps);
3594 }
3595 }
3596}
3597
#[inline(always)]
fn apply_adam_vec_update(
    param: &mut [f32],
    grad: &[f32],
    adam: &mut AdamTensorState,
    step: &AdamStep,
) {
    // Adam update for a dense parameter vector, using the persistent moment
    // buffers carried in `adam`. Gradient-ascent convention: the
    // bias-corrected step is ADDED to `param` (matches `sgd_vec_update`).
    //
    // The unclipped path runs 8 lanes at a time with `wide::f32x8`; the
    // clipped path and the length tail stay scalar. Each element is updated
    // independently with the same arithmetic in both paths.
    //
    // Length is truncated to the shortest of the four buffers so a size
    // mismatch can never index out of bounds.
    let n = param
        .len()
        .min(grad.len())
        .min(adam.m.len())
        .min(adam.v.len());
    if n == 0 {
        return;
    }
    let b1 = step.b1;
    let b2 = step.b2;
    let one_m_b1 = 1.0 - b1;
    let one_m_b2 = 1.0 - b2;
    let lr = step.lr;
    let eps = step.eps;
    // Bias correction applied as a multiplication by the reciprocal.
    let inv_bc1 = 1.0 / step.bias_corr1;
    let inv_bc2 = 1.0 / step.bias_corr2;
    let do_clip = step.clip > 0.0;
    let clip = step.clip;
    let m = adam.m.as_mut_slice();
    let v = adam.v.as_mut_slice();
    if do_clip {
        // Clipped path: scalar loop, per-element clamp of the raw gradient.
        for idx in 0..n {
            let g = grad[idx].clamp(-clip, clip);
            let mm = b1 * m[idx] + one_m_b1 * g;
            let vv = b2 * v[idx] + one_m_b2 * g * g;
            m[idx] = mm;
            v[idx] = vv;
            let m_hat = mm * inv_bc1;
            let v_hat = vv * inv_bc2;
            param[idx] += lr * m_hat / (v_hat.sqrt() + eps);
        }
    } else {
        let mut idx = 0usize;
        // SAFETY: every load/store below touches indices idx..idx+8 of a
        // slice whose length is >= n, and the loop条 condition guarantees
        // idx + 8 <= n; unaligned reads/writes are used throughout, so no
        // alignment requirement applies.
        unsafe {
            let b1v = f32x8::splat(b1);
            let b2v = f32x8::splat(b2);
            let one_b1v = f32x8::splat(one_m_b1);
            let one_b2v = f32x8::splat(one_m_b2);
            let inv_bc1v = f32x8::splat(inv_bc1);
            let inv_bc2v = f32x8::splat(inv_bc2);
            let lrv = f32x8::splat(lr);
            let epsv = f32x8::splat(eps);
            while idx + 8 <= n {
                let gv = grad.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
                let mv = m.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
                let vv = v.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
                // Moment EMAs: m = b1*m + (1-b1)*g, v = b2*v + (1-b2)*g^2.
                let mm = mv * b1v + gv * one_b1v;
                let vv2 = vv * b2v + (gv * gv) * one_b2v;
                m.as_mut_ptr().add(idx).cast::<f32x8>().write_unaligned(mm);
                v.as_mut_ptr().add(idx).cast::<f32x8>().write_unaligned(vv2);

                // param += lr * m_hat / (sqrt(v_hat) + eps)
                let pv = param.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
                let upd = ((mm * inv_bc1v) / ((vv2 * inv_bc2v).sqrt() + epsv)) * lrv;
                param
                    .as_mut_ptr()
                    .add(idx)
                    .cast::<f32x8>()
                    .write_unaligned(pv + upd);
                idx += 8;
            }
        }
        // Scalar tail for the final n % 8 elements.
        while idx < n {
            let g = grad[idx];
            let mm = b1 * m[idx] + one_m_b1 * g;
            let vv = b2 * v[idx] + one_m_b2 * g * g;
            m[idx] = mm;
            v[idx] = vv;
            let m_hat = mm * inv_bc1;
            let v_hat = vv * inv_bc2;
            param[idx] += lr * m_hat / (v_hat.sqrt() + eps);
            idx += 1;
        }
    }
}
3679
#[inline(always)]
fn apply_adam_vec_update_and_sync_neg_exp(
    param_log: &mut [f32],
    param_value: &mut [f32],
    grad: &[f32],
    adam: &mut AdamTensorState,
    step: &AdamStep,
) {
    // Adam update for a parameter stored in log space, keeping a second
    // buffer in sync as `param_value[i] = -exp(param_log[i])` after every
    // step (presumably the negative-definite SSM A parameterisation — confirm
    // with callers). Same gradient-ascent convention and SIMD/scalar split
    // as `apply_adam_vec_update`.
    let n = param_log
        .len()
        .min(param_value.len())
        .min(grad.len())
        .min(adam.m.len())
        .min(adam.v.len());
    if n == 0 {
        return;
    }
    let b1 = step.b1;
    let b2 = step.b2;
    let one_m_b1 = 1.0 - b1;
    let one_m_b2 = 1.0 - b2;
    let lr = step.lr;
    let eps = step.eps;
    let inv_bc1 = 1.0 / step.bias_corr1;
    let inv_bc2 = 1.0 / step.bias_corr2;
    let do_clip = step.clip > 0.0;
    let clip = step.clip;
    let m = adam.m.as_mut_slice();
    let v = adam.v.as_mut_slice();
    if do_clip {
        // Clipped path: scalar only; the log parameter is advanced and the
        // mirrored value recomputed element by element.
        for idx in 0..n {
            let g = grad[idx].clamp(-clip, clip);
            let mm = b1 * m[idx] + one_m_b1 * g;
            let vv = b2 * v[idx] + one_m_b2 * g * g;
            m[idx] = mm;
            v[idx] = vv;
            let m_hat = mm * inv_bc1;
            let v_hat = vv * inv_bc2;
            let new_log = param_log[idx] + lr * m_hat / (v_hat.sqrt() + eps);
            param_log[idx] = new_log;
            param_value[idx] = -new_log.exp();
        }
        return;
    }

    let mut idx = 0usize;
    // SAFETY: every load/store touches indices idx..idx+8 of slices whose
    // lengths are >= n, and the loop condition guarantees idx + 8 <= n;
    // unaligned reads/writes are used, so no alignment requirement applies.
    unsafe {
        let b1v = f32x8::splat(b1);
        let b2v = f32x8::splat(b2);
        let one_b1v = f32x8::splat(one_m_b1);
        let one_b2v = f32x8::splat(one_m_b2);
        let inv_bc1v = f32x8::splat(inv_bc1);
        let inv_bc2v = f32x8::splat(inv_bc2);
        let lrv = f32x8::splat(lr);
        let epsv = f32x8::splat(eps);
        while idx + 8 <= n {
            let gv = grad.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
            let mv = m.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
            let vv = v.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
            let mm = mv * b1v + gv * one_b1v;
            let vv2 = vv * b2v + (gv * gv) * one_b2v;
            m.as_mut_ptr().add(idx).cast::<f32x8>().write_unaligned(mm);
            v.as_mut_ptr().add(idx).cast::<f32x8>().write_unaligned(vv2);

            let pv = param_log.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
            let new_log = pv + ((mm * inv_bc1v) / ((vv2 * inv_bc2v).sqrt() + epsv)) * lrv;
            param_log
                .as_mut_ptr()
                .add(idx)
                .cast::<f32x8>()
                .write_unaligned(new_log);
            // exp() has no f32x8 form here; sync the mirror buffer lane by lane.
            let lanes = new_log.to_array();
            for (lane, value) in lanes.iter().enumerate() {
                param_value[idx + lane] = -value.exp();
            }
            idx += 8;
        }
    }
    // Scalar tail for the final n % 8 elements.
    while idx < n {
        let g = grad[idx];
        let mm = b1 * m[idx] + one_m_b1 * g;
        let vv = b2 * v[idx] + one_m_b2 * g * g;
        m[idx] = mm;
        v[idx] = vv;
        let m_hat = mm * inv_bc1;
        let v_hat = vv * inv_bc2;
        let new_log = param_log[idx] + lr * m_hat / (v_hat.sqrt() + eps);
        param_log[idx] = new_log;
        param_value[idx] = -new_log.exp();
        idx += 1;
    }
}
3772
#[inline(always)]
#[allow(clippy::needless_range_loop)]
fn apply_adam_outer_update(
    param: &mut [f32],
    rows: usize,
    cols: usize,
    left: &[f32],
    right: &[f32],
    adam: &mut AdamTensorState,
    step: &AdamStep,
) {
    // Adam update for a row-major matrix parameter whose gradient is the
    // rank-1 outer product grad[r*cols + c] = left[r] * right[c], formed on
    // the fly so the full gradient matrix is never materialised. Same
    // gradient-ascent convention and SIMD/scalar split as
    // `apply_adam_vec_update`; rows that fall past `param`'s length are
    // truncated (partial last row) or dropped.
    let rows = rows.min(left.len());
    let cols = cols.min(right.len());
    let n = param.len().min(adam.m.len()).min(adam.v.len());
    if rows == 0 || cols == 0 || n == 0 {
        return;
    }
    let b1 = step.b1;
    let b2 = step.b2;
    let one_m_b1 = 1.0 - b1;
    let one_m_b2 = 1.0 - b2;
    let lr = step.lr;
    let eps = step.eps;
    let inv_bc1 = 1.0 / step.bias_corr1;
    let inv_bc2 = 1.0 / step.bias_corr2;
    let do_clip = step.clip > 0.0;
    let clip = step.clip;
    let m = adam.m.as_mut_slice();
    let v = adam.v.as_mut_slice();
    // Splats are hoisted out of the row loop (also used only in the
    // unclipped SIMD branch below).
    let b1v = f32x8::splat(b1);
    let b2v = f32x8::splat(b2);
    let one_b1v = f32x8::splat(one_m_b1);
    let one_b2v = f32x8::splat(one_m_b2);
    let inv_bc1v = f32x8::splat(inv_bc1);
    let inv_bc2v = f32x8::splat(inv_bc2);
    let epsv = f32x8::splat(eps);
    let lrv = f32x8::splat(lr);
    for row in 0..rows {
        let g_row = left[row];
        let off = row * cols;
        if off >= n {
            break;
        }
        // Clamp this row's width to what still fits in the flat buffers.
        let row_cols = (n - off).min(cols);
        if do_clip {
            // Clipped path: scalar; the clamp applies to the outer-product
            // gradient term, not to left/right individually.
            for col in 0..row_cols {
                let idx = off + col;
                let g = (g_row * right[col]).clamp(-clip, clip);
                let mm = b1 * m[idx] + one_m_b1 * g;
                let vv = b2 * v[idx] + one_m_b2 * g * g;
                m[idx] = mm;
                v[idx] = vv;
                let m_hat = mm * inv_bc1;
                let v_hat = vv * inv_bc2;
                param[idx] += lr * m_hat / (v_hat.sqrt() + eps);
            }
        } else {
            let mut col = 0usize;
            // SAFETY: reads of `right` touch col..col+8 with col + 8 <=
            // row_cols <= cols <= right.len(); reads/writes of m, v and
            // param touch off+col..off+col+8 with off + col + 8 <= off +
            // row_cols <= n <= each buffer's length. Unaligned accesses are
            // used, so no alignment requirement applies.
            unsafe {
                let g8 = f32x8::splat(g_row);
                while col + 8 <= row_cols {
                    let idx = off + col;
                    let rv = right.as_ptr().add(col).cast::<f32x8>().read_unaligned();
                    let gv = g8 * rv;

                    let mv = m.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
                    let vv = v.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
                    let mm = mv * b1v + gv * one_b1v;
                    let vv2 = vv * b2v + (gv * gv) * one_b2v;
                    m.as_mut_ptr().add(idx).cast::<f32x8>().write_unaligned(mm);
                    v.as_mut_ptr().add(idx).cast::<f32x8>().write_unaligned(vv2);

                    let pv = param.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
                    let upd = ((mm * inv_bc1v) / ((vv2 * inv_bc2v).sqrt() + epsv)) * lrv;
                    param
                        .as_mut_ptr()
                        .add(idx)
                        .cast::<f32x8>()
                        .write_unaligned(pv + upd);
                    col += 8;
                }
            }
            // Scalar tail for the final row_cols % 8 columns.
            while col < row_cols {
                let idx = off + col;
                let g = g_row * right[col];
                let mm = b1 * m[idx] + one_m_b1 * g;
                let vv = b2 * v[idx] + one_m_b2 * g * g;
                m[idx] = mm;
                v[idx] = vv;
                let m_hat = mm * inv_bc1;
                let v_hat = vv * inv_bc2;
                param[idx] += lr * m_hat / (v_hat.sqrt() + eps);
                col += 1;
            }
        }
    }
}
3870
3871fn depthwise_conv_step(
3872 x: &[f32],
3873 conv_w: &Tensor1D,
3874 conv_b: Option<&Tensor1D>,
3875 conv_kernel: usize,
3876 state: &mut LayerState,
3877 out: &mut [f32],
3878) {
3879 if conv_kernel == 4 {
3880 depthwise_conv_step_k4(x, conv_w, conv_b, state, out);
3881 return;
3882 }
3883 let inner = x.len();
3884 debug_assert_eq!(out.len(), inner);
3885 debug_assert_eq!(conv_w.len(), inner * conv_kernel);
3886
3887 let pos = state.conv_pos;
3888 let conv_state = state.conv.as_mut_slice();
3889 let weight = conv_w.as_slice();
3890
3891 for ch in 0..inner {
3892 let base = ch * conv_kernel;
3893 conv_state[base + pos] = x[ch];
3894
3895 let mut acc = conv_b.as_ref().map_or(0.0, |b| b[ch]);
3896 let mut ring_idx = pos;
3897 for tap in 0..conv_kernel {
3898 acc += conv_state[base + ring_idx] * weight[base + tap];
3899 ring_idx = if ring_idx == 0 {
3900 conv_kernel - 1
3901 } else {
3902 ring_idx - 1
3903 };
3904 }
3905 out[ch] = acc;
3906 }
3907
3908 state.conv_pos = if pos + 1 == conv_kernel { 0 } else { pos + 1 };
3909}
3910
3911#[inline(always)]
3912fn depthwise_conv_step_k4(
3913 x: &[f32],
3914 conv_w: &Tensor1D,
3915 conv_b: Option<&Tensor1D>,
3916 state: &mut LayerState,
3917 out: &mut [f32],
3918) {
3919 let inner = x.len();
3920 debug_assert_eq!(out.len(), inner);
3921 debug_assert_eq!(conv_w.len(), inner * 4);
3922
3923 let pos = state.conv_pos;
3924 let conv_state = state.conv.as_mut_slice();
3925 let weight = conv_w.as_slice();
3926
3927 for ch in 0..inner {
3928 let base = ch * 4;
3929 conv_state[base + pos] = x[ch];
3930 let acc = match pos {
3931 0 => {
3932 conv_state[base] * weight[base]
3933 + conv_state[base + 3] * weight[base + 1]
3934 + conv_state[base + 2] * weight[base + 2]
3935 + conv_state[base + 1] * weight[base + 3]
3936 }
3937 1 => {
3938 conv_state[base + 1] * weight[base]
3939 + conv_state[base] * weight[base + 1]
3940 + conv_state[base + 3] * weight[base + 2]
3941 + conv_state[base + 2] * weight[base + 3]
3942 }
3943 2 => {
3944 conv_state[base + 2] * weight[base]
3945 + conv_state[base + 1] * weight[base + 1]
3946 + conv_state[base] * weight[base + 2]
3947 + conv_state[base + 3] * weight[base + 3]
3948 }
3949 _ => {
3950 conv_state[base + 3] * weight[base]
3951 + conv_state[base + 2] * weight[base + 1]
3952 + conv_state[base + 1] * weight[base + 2]
3953 + conv_state[base] * weight[base + 3]
3954 }
3955 };
3956 out[ch] = acc + conv_b.as_ref().map_or(0.0, |b| b[ch]);
3957 }
3958
3959 state.conv_pos = (pos + 1) & 3;
3960}
3961
// One token of the selective-scan recurrence for a single channel row with
// a hard-coded state size of 16:
//
//   state[j] = state[j] * exp(dt * A[j]) + x_dt * B[j]
//   y        = sum_j state[j] * C[j]
//
// When `CAPTURE` is true, the per-element decay factors exp(dt * A[j]) are
// also written to `trace_d_a` (consumed by the training backward pass —
// confirm against the callers that enable capture).
//
// Safety: `row_a`, `b_ptr` and `c_ptr` must be valid for reads of 16 f32s,
// `row_ssm` must be valid for reads and writes of 16 f32s, and `trace_d_a`
// must be valid for writes of 16 f32s whenever `CAPTURE` is true; the
// mutable regions must not alias the read-only ones.
#[inline(always)]
unsafe fn selective_scan_state16<const CAPTURE: bool>(
    row_a: *const f32,
    row_ssm: *mut f32,
    dt: f32,
    x_dt: f32,
    b_ptr: *const f32,
    c_ptr: *const f32,
    trace_d_a: *mut f32,
) -> f32 {
    let mut y = 0.0f32;
    let mut j = 0usize;
    while j < 16 {
        let prev = *row_ssm.add(j);
        // Zero-order-hold discretisation of the continuous-time decay.
        let d_a = (dt * *row_a.add(j)).exp();
        if CAPTURE {
            *trace_d_a.add(j) = d_a;
        }
        let next = prev * d_a + x_dt * *b_ptr.add(j);
        *row_ssm.add(j) = next;
        y += next * *c_ptr.add(j);
        j += 1;
    }
    y
}
3987
#[inline(always)]
fn silu(x: f32) -> f32 {
    // SiLU / swish activation: x * sigmoid(x), written as one division.
    let denom = (-x).exp() + 1.0;
    x / denom
}
3992
#[inline(always)]
fn sigmoid(x: f32) -> f32 {
    // Logistic function: reciprocal of (1 + e^-x).
    ((-x).exp() + 1.0).recip()
}
3997
#[inline(always)]
fn silu_with_sigmoid(x: f32) -> (f32, f32) {
    // Returns (silu(x), sigmoid(x)), sharing the (1 + e^-x) denominator so
    // the exponential is evaluated only once.
    let denom = (-x).exp() + 1.0;
    (x / denom, denom.recip())
}
4003
#[inline(always)]
fn silu_grad_from_sigmoid(x: f32, s: f32) -> f32 {
    // d/dx silu(x) = s + x*s*(1-s), factored as s * (1 + x*(1-s)) where
    // `s` is sigmoid(x) precomputed by the caller.
    let inner = x * (1.0 - s);
    s * (1.0 + inner)
}
4008
#[inline(always)]
fn softplus(x: f32) -> f32 {
    // softplus(x) = ln(1 + e^x). For x > 20 the result equals x to f32
    // precision (and e^x would overflow f32 past ~88), so saturate early.
    //
    // Fix: use `ln_1p` instead of `(1.0 + x.exp()).ln()`. The latter rounds
    // 1 + e^x to exactly 1.0 once e^x drops below f32 epsilon (x ≲ -16) and
    // returns 0.0, flattening the function — and its gradient — for strongly
    // negative inputs; `ln_1p` keeps full precision there.
    if x > 20.0 { x } else { x.exp().ln_1p() }
}
4013
/// Small deterministic RNG used for reproducible weight initialisation.
/// Not cryptographic — it only needs to give the same stream for the same
/// seed across runs.
struct MambaRng {
    // Current 64-bit LCG state.
    state: u64,
}

impl MambaRng {
    /// Seeds the generator; XOR with a golden-ratio constant decorrelates
    /// nearby seed values.
    fn new(seed: u64) -> Self {
        Self {
            state: seed ^ 0x9E37_79B9_7F4A_7C15,
        }
    }

    /// Advances the 64-bit LCG and returns its high 32 bits (the high half
    /// of an LCG has better statistical quality than the low half).
    #[inline]
    fn next_u32(&mut self) -> u32 {
        const MUL: u64 = 6_364_136_223_846_793_005;
        self.state = self.state.wrapping_mul(MUL).wrapping_add(1);
        (self.state >> 32) as u32
    }

    /// Uniform draw scaled by 1 / (u32::MAX as f32). Note u32::MAX rounds up
    /// to 2^32 as an f32, so the factor is exactly 2^-32 and the very top
    /// draw maps to exactly 1.0.
    #[inline]
    fn next_f32(&mut self) -> f32 {
        (self.next_u32() as f32) * (1.0 / (u32::MAX as f32))
    }
}
4040
4041#[inline]
4042fn init_uniform(t: &mut Tensor1D, rng: &mut MambaRng, scale: f32) {
4043 for v in t.as_mut_slice() {
4044 let r = rng.next_f32() - 0.5;
4045 *v = r * 2.0 * scale;
4046 }
4047}
4048
4049#[inline]
4050fn init_const(t: &mut Tensor1D, value: f32) {
4051 t.as_mut_slice().fill(value);
4052}
4053
#[cfg(test)]
mod tests {
    use super::*;

    // Log-probability of `target` after a single forward pass on `token`
    // from a fresh state. Softmax is max-shifted and accumulated in f64 for
    // numerical stability; the probability is floored to guard ln(0).
    fn target_log_prob(model: &Model, token: u32, target: u8) -> f32 {
        let mut state = model.new_state();
        let mut scratch = ScratchBuffers::new(model.config());
        let logits = model.forward(&mut scratch, token, &mut state);
        let mut max_logit = f32::NEG_INFINITY;
        for &logit in logits {
            max_logit = max_logit.max(logit);
        }
        let mut denom = 0.0f64;
        for &logit in logits {
            denom += ((logit - max_logit) as f64).exp();
        }
        let p = ((logits[target as usize] - max_logit) as f64).exp() / denom;
        p.max(1e-30).ln() as f32
    }

    // Two independent state/scratch pairs fed the same token stream must
    // produce bitwise-identical logits at every step.
    #[test]
    fn forward_is_deterministic_for_same_input_and_state() {
        let cfg = Config {
            vocab_size: 256,
            hidden_size: 64,
            num_layers: 2,
            inner_size: 96,
            state_size: 8,
            conv_kernel: 4,
            dt_rank: 8,
            layer_norm_eps: 1e-5,
        };
        let model = Model::new_random(cfg.clone(), 1234).expect("random model");
        let mut s1 = model.new_state();
        let mut s2 = model.new_state();
        let mut b1 = ScratchBuffers::new(&cfg);
        let mut b2 = ScratchBuffers::new(&cfg);

        let seq = b"deterministic mamba";
        for &tok in seq {
            let l1 = model.forward(&mut b1, tok as u32, &mut s1).to_vec();
            let l2 = model.forward(&mut b2, tok as u32, &mut s2).to_vec();
            assert_eq!(l1.len(), l2.len());
            for (a, b) in l1.iter().zip(l2.iter()) {
                // Bitwise comparison: determinism must hold exactly, not
                // merely to a tolerance.
                assert_eq!(a.to_bits(), b.to_bits());
            }
        }
    }

    // Enabling train-trace capture must not perturb the forward pass at
    // all: logits, conv ring buffers, SSM states and ring positions are all
    // compared bitwise against an untraced run.
    #[test]
    fn traced_and_untraced_forward_match_exactly() {
        let cfg = Config {
            vocab_size: 256,
            hidden_size: 64,
            num_layers: 2,
            inner_size: 96,
            state_size: 8,
            conv_kernel: 4,
            dt_rank: 8,
            layer_norm_eps: 1e-5,
        };
        let model = Model::new_random(cfg.clone(), 4321).expect("random model");
        let mut traced_state = model.new_state();
        let mut plain_state = model.new_state();
        let mut traced_scratch = ScratchBuffers::new(&cfg);
        let mut plain_scratch = ScratchBuffers::new(&cfg);
        traced_scratch.set_capture_train_trace(true);
        plain_scratch.set_capture_train_trace(false);

        let seq = b"trace equivalence for mamba";
        for &tok in seq {
            let traced_logits = model
                .forward(&mut traced_scratch, tok as u32, &mut traced_state)
                .to_vec();
            let plain_logits = model
                .forward(&mut plain_scratch, tok as u32, &mut plain_state)
                .to_vec();
            for (a, b) in traced_logits.iter().zip(plain_logits.iter()) {
                assert_eq!(a.to_bits(), b.to_bits());
            }
            for (tr_layer, plain_layer) in traced_state.layers.iter().zip(plain_state.layers.iter())
            {
                for (&a, &b) in tr_layer
                    .conv
                    .as_slice()
                    .iter()
                    .zip(plain_layer.conv.as_slice())
                {
                    assert_eq!(a.to_bits(), b.to_bits());
                }
                for (&a, &b) in tr_layer
                    .ssm
                    .as_slice()
                    .iter()
                    .zip(plain_layer.ssm.as_slice())
                {
                    assert_eq!(a.to_bits(), b.to_bits());
                }
                assert_eq!(tr_layer.conv_pos, plain_layer.conv_pos);
            }
        }
    }

    // Checks the analytic embedding gradient against a central finite
    // difference. The analytic gradient is recovered from the parameter
    // delta of a single SGD step divided by the learning rate.
    #[test]
    fn online_embed_gradient_matches_finite_difference() {
        let cfg = Config {
            vocab_size: 256,
            hidden_size: 16,
            num_layers: 2,
            inner_size: 24,
            state_size: 4,
            conv_kernel: 3,
            dt_rank: 4,
            layer_norm_eps: 1e-5,
        };
        let token = 7u32;
        let target = 19u8;
        let lr = 1e-3f32;
        let eps = 1e-3f32;

        let model = Model::new_random(cfg.clone(), 99).expect("random model");
        let mut state = model.new_state();
        let mut scratch = ScratchBuffers::new(&cfg);
        scratch.set_capture_train_trace(true);
        let logits = model.forward(&mut scratch, token, &mut state);

        // Stable softmax in f64, reused by the training step below.
        let mut pdf = vec![0.0f64; cfg.vocab_size];
        let mut max_logit = f32::NEG_INFINITY;
        for &logit in logits {
            max_logit = max_logit.max(logit);
        }
        let mut denom = 0.0f64;
        for &logit in logits {
            denom += ((logit - max_logit) as f64).exp();
        }
        for (idx, out) in pdf.iter_mut().enumerate() {
            *out = ((logits[idx] - max_logit) as f64).exp() / denom;
        }

        let base = model.clone();
        let mut trained = base.clone();
        let mut train_scratch = scratch.clone();
        // Train only the embedding table so the probed parameter moves by
        // exactly lr * gradient.
        trained
            .online_train_step_bptt1(
                &mut train_scratch,
                &state,
                target,
                &pdf,
                TrainScopeMask {
                    embed: true,
                    ..TrainScopeMask::default()
                },
                OptimizerKind::Sgd,
                lr,
                0.0,
                &mut 0usize,
                None,
                None,
                None,
                None,
            )
            .expect("training step");

        let param_idx = token as usize * cfg.hidden_size;
        let analytic = (trained.embeddings[param_idx] - base.embeddings[param_idx]) / lr;

        // Central finite difference of the target log-probability.
        let mut plus = base.clone();
        plus.embeddings[param_idx] += eps;
        let mut minus = base.clone();
        minus.embeddings[param_idx] -= eps;
        let numeric = (target_log_prob(&plus, token, target)
            - target_log_prob(&minus, token, target))
            / (2.0 * eps);

        let diff = (analytic - numeric).abs();
        let scale = analytic.abs().max(numeric.abs()).max(1.0);
        assert!(
            diff <= 2e-2 * scale,
            "analytic={analytic} numeric={numeric} diff={diff}"
        );
    }

    // Small configuration shared by the TBPTT tests below.
    fn test_cfg() -> Config {
        Config {
            vocab_size: 256,
            hidden_size: 32,
            num_layers: 1,
            inner_size: 48,
            state_size: 6,
            conv_kernel: 3,
            dt_rank: 6,
            layer_norm_eps: 1e-5,
        }
    }

    // Cross-entropy of `target` under a max-shifted f64 softmax.
    fn softmax_loss(logits: &[f32], target: u8) -> f64 {
        let max_logit = logits
            .iter()
            .copied()
            .fold(f32::NEG_INFINITY, |a, b| a.max(b));
        let mut denom = 0.0f64;
        for &z in logits {
            denom += ((z - max_logit) as f64).exp();
        }
        let p = ((logits[target as usize] - max_logit) as f64).exp() / denom.max(1e-300);
        -p.max(1e-300).ln()
    }

    // Full softmax distribution in f64 (same max-shift stabilisation).
    fn softmax_pdf(logits: &[f32]) -> Vec<f64> {
        let mut pdf = vec![0.0f64; logits.len()];
        let max_logit = logits
            .iter()
            .copied()
            .fold(f32::NEG_INFINITY, |a, b| a.max(b));
        let mut denom = 0.0f64;
        for &z in logits {
            denom += ((z - max_logit) as f64).exp();
        }
        let inv = 1.0 / denom.max(1e-300);
        for (idx, out) in pdf.iter_mut().enumerate() {
            *out = ((logits[idx] - max_logit) as f64).exp() * inv;
        }
        pdf
    }

    // Mean cross-entropy over a (input, target) segment from a fresh state.
    fn segment_loss(model: &Model, cfg: &Config, steps: &[(u32, u8)]) -> f64 {
        if steps.is_empty() {
            return 0.0;
        }
        let mut scratch = ScratchBuffers::new(cfg);
        let mut state = model.new_state();
        let mut loss = 0.0f64;
        for &(input, target) in steps {
            let logits = model.forward(&mut scratch, input, &mut state);
            loss += softmax_loss(logits, target);
        }
        loss / (steps.len() as f64)
    }

    // Full backward-through-time gradients for a segment: forward once with
    // trace capture (saving per-step traces, pdfs and post-step states),
    // then accumulate token gradients in reverse order.
    fn segment_grads(model: &Model, cfg: &Config, steps: &[(u32, u8)]) -> FullGradState {
        let mut scratch = ScratchBuffers::new(cfg);
        let mut state = model.new_state();
        let mut states = Vec::with_capacity(steps.len() + 1);
        let mut traces = Vec::with_capacity(steps.len());
        let mut pdfs = Vec::with_capacity(steps.len());
        states.push(state.clone());
        for &(input, _) in steps {
            scratch.set_capture_train_trace(true);
            let logits = model.forward(&mut scratch, input, &mut state);
            pdfs.push(softmax_pdf(logits));
            traces.push(TokenTrainTrace::from_scratch(&scratch));
            states.push(state.clone());
        }
        let mut grads = model.new_full_grad_state();
        let mut recurrent = model.new_recurrent_grad_state();
        recurrent.zero();
        // Everything except biases is in scope for the gradient check.
        let scope = TrainScopeMask {
            embed: true,
            layer_norm: true,
            mixer_conv: true,
            mixer_ssm: true,
            mixer_proj: true,
            head: true,
            bias: false,
        };
        // 1/len scaling matches segment_loss's mean.
        let grad_scale = 1.0f32 / (steps.len() as f32);
        for idx in (0..steps.len()).rev() {
            model
                .accumulate_token_step_gradients(
                    &mut scratch,
                    &traces[idx],
                    &states[idx + 1],
                    steps[idx].1,
                    &pdfs[idx],
                    grad_scale,
                    scope,
                    &mut grads,
                    None,
                    &mut recurrent,
                )
                .expect("segment gradient accumulation");
        }
        grads
    }

    // One representative scalar parameter per weight family, used to spot
    // gradient bugs in each part of the model.
    #[derive(Clone, Copy, Debug)]
    enum Probe {
        Embed,
        FinalNormW,
        LayerNormW,
        InProjW,
        ConvW,
        SsmA,
        OutProjW,
        LmHead,
    }

    // Reads the probed parameter (indices are arbitrary but fixed, and must
    // stay consistent across probe_value / set_probe / probe_grad).
    fn probe_value(model: &Model, probe: Probe) -> f32 {
        match probe {
            Probe::Embed => model.embeddings[7],
            Probe::FinalNormW => model.final_norm_w[5],
            Probe::LayerNormW => model.layers[0].norm_w[9],
            Probe::InProjW => model.layers[0].in_proj_w[13],
            Probe::ConvW => model.layers[0].conv_w[4],
            Probe::SsmA => model.layers[0].a[11],
            Probe::OutProjW => model.layers[0].out_proj_w[17],
            Probe::LmHead => model.lm_head[23],
        }
    }

    // Writes the probed parameter (same fixed indices as probe_value).
    fn set_probe(model: &mut Model, probe: Probe, value: f32) {
        match probe {
            Probe::Embed => model.embeddings[7] = value,
            Probe::FinalNormW => model.final_norm_w[5] = value,
            Probe::LayerNormW => model.layers[0].norm_w[9] = value,
            Probe::InProjW => model.layers[0].in_proj_w[13] = value,
            Probe::ConvW => model.layers[0].conv_w[4] = value,
            Probe::SsmA => model.layers[0].a[11] = value,
            Probe::OutProjW => model.layers[0].out_proj_w[17] = value,
            Probe::LmHead => model.lm_head[23] = value,
        }
    }

    // Reads the accumulated gradient at the probed position.
    fn probe_grad(grads: &FullGradState, probe: Probe) -> f32 {
        match probe {
            Probe::Embed => grads.embeddings[7],
            Probe::FinalNormW => grads.final_norm_w[5],
            Probe::LayerNormW => grads.layers[0].norm_w[9],
            Probe::InProjW => grads.layers[0].in_proj_w[13],
            Probe::ConvW => grads.layers[0].conv_w[4],
            Probe::SsmA => grads.layers[0].a[11],
            Probe::OutProjW => grads.layers[0].out_proj_w[17],
            Probe::LmHead => grads.lm_head[23],
        }
    }

    // Central finite-difference check of the TBPTT segment gradients for one
    // probe parameter in every weight family.
    #[test]
    fn tbptt_segment_gradients_match_finite_difference() {
        let cfg = test_cfg();
        cfg.validate().expect("valid test config");
        let model = Model::new_random(cfg.clone(), 0xD00D_F00D).expect("random model");
        let steps = [(0u32, 1u8), (1, 2), (2, 3)];
        let grads = segment_grads(&model, &cfg, &steps);
        let eps = 1e-3f32;

        for probe in [
            Probe::Embed,
            Probe::FinalNormW,
            Probe::LayerNormW,
            Probe::InProjW,
            Probe::ConvW,
            Probe::SsmA,
            Probe::OutProjW,
            Probe::LmHead,
        ] {
            let analytic = probe_grad(&grads, probe);

            let mut plus = model.clone();
            let base = probe_value(&plus, probe);
            set_probe(&mut plus, probe, base + eps);
            let loss_plus = segment_loss(&plus, &cfg, &steps);

            let mut minus = model.clone();
            set_probe(&mut minus, probe, base - eps);
            let loss_minus = segment_loss(&minus, &cfg, &steps);

            // Negated: accumulated gradients point toward higher
            // log-likelihood (ascent), losses toward lower loss.
            let numeric = -((loss_plus - loss_minus) / (2.0 * eps as f64)) as f32;
            let tol = 6e-2f32.max(analytic.abs().max(numeric.abs()) * 1e-1);
            assert!(
                (analytic - numeric).abs() <= tol,
                "probe={probe:?} analytic={analytic} numeric={numeric} tol={tol}"
            );
        }
    }

    // End-to-end sanity: one SGD TBPTT segment step should lower the mean
    // segment loss on the very data it was trained on.
    #[test]
    fn tbptt_sgd_step_reduces_mean_segment_loss() {
        let cfg = test_cfg();
        cfg.validate().expect("valid test config");
        let mut model = Model::new_random(cfg.clone(), 0x1234_5678).expect("random model");
        let steps = [(0u32, 1u8), (1, 2), (2, 3), (3, 4)];
        let before = segment_loss(&model, &cfg, &steps);

        // Precompute per-step pdfs with the pre-update weights, as the
        // segment trainer expects.
        let mut scratch = ScratchBuffers::new(&cfg);
        let start_state = model.new_state();
        let mut state = start_state.clone();
        let mut segment_steps = Vec::with_capacity(steps.len());
        for &(input, target) in &steps {
            let logits = model.forward(&mut scratch, input, &mut state);
            segment_steps.push((input, target, softmax_pdf(logits)));
        }

        let mut live_state = model.new_state();
        let mut adam_t = 0usize;
        let scope = TrainScopeMask {
            embed: true,
            layer_norm: true,
            mixer_conv: true,
            mixer_ssm: true,
            mixer_proj: true,
            head: true,
            bias: false,
        };

        model
            .online_train_segment_tbptt(
                &mut scratch,
                &start_state,
                &segment_steps,
                scope,
                OptimizerKind::Sgd,
                8e-4,
                0.0,
                2,
                &mut adam_t,
                None,
                None,
                None,
                None,
                &mut live_state,
            )
            .expect("tbptt sgd step");

        let after = segment_loss(&model, &cfg, &steps);
        assert!(
            after < before,
            "expected SGD TBPTT step to reduce mean loss: before={before} after={after}"
        );
    }
}