1use crate::backends::llm_policy::OptimizerKind;
7use anyhow::{Context, Result, bail};
8use serde_json::json;
9use std::fs::File;
10use std::io::Write;
11use std::path::Path;
12use std::time::Instant;
13use wide::f32x8;
14
15use super::kernel;
16use super::profiling::{NullProfiler, ProfilerSink};
17use super::tensor::Tensor1D;
18use super::weights::Weights;
19
/// Shape hyperparameters for an RWKV7 model.
#[derive(Debug, Clone)]
pub struct Config {
    /// Number of entries in the token vocabulary.
    pub vocab_size: usize,
    /// Model width; must equal `num_heads * head_dim`.
    pub hidden_size: usize,
    /// Number of stacked blocks.
    pub num_layers: usize,
    /// Attention head count.
    pub num_heads: usize,
    /// Channels per head; current kernels require 64.
    pub head_dim: usize,
    /// Inner width of the feed-forward sub-block.
    pub intermediate_size: usize,
    /// Epsilon for the layer norms.
    pub layer_norm_eps: f32,
    /// Epsilon for the per-head group norm.
    pub group_norm_eps: f32,

    /// Low-rank dimension of the decay (`w`) LoRA.
    pub decay_low_rank: usize,
    /// Low-rank dimension of the `a` LoRA.
    pub a_low_rank: usize,
    /// Low-rank dimension of the `v` LoRA.
    pub v_low_rank: usize,
    /// Low-rank dimension of the gate (`g`) LoRA.
    pub g_low_rank: usize,
}

impl Default for Config {
    /// Small default configuration: 256-wide, 12 layers, 4 heads of 64.
    fn default() -> Self {
        Self {
            vocab_size: 256,
            hidden_size: 256,
            num_layers: 12,
            num_heads: 4,
            head_dim: 64,
            intermediate_size: 1024,
            layer_norm_eps: 1e-5,
            group_norm_eps: 64e-5,
            decay_low_rank: 32,
            a_low_rank: 32,
            v_low_rank: 32,
            g_low_rank: 64,
        }
    }
}
68
69impl Config {
70 pub fn validate(&self) -> Result<()> {
72 if self.vocab_size == 0 {
73 bail!("rwkv7 vocab_size must be > 0");
74 }
75 if self.head_dim != 64 {
76 bail!("rwkv7 head_dim must be 64 for current kernels");
77 }
78 if self.hidden_size != self.num_heads * self.head_dim {
79 bail!(
80 "rwkv7 hidden_size must equal num_heads * head_dim ({} != {} * {})",
81 self.hidden_size,
82 self.num_heads,
83 self.head_dim
84 );
85 }
86 if self.num_layers == 0 {
87 bail!("rwkv7 num_layers must be > 0");
88 }
89 if self.intermediate_size == 0 {
90 bail!("rwkv7 intermediate_size must be > 0");
91 }
92 Ok(())
93 }
94}
95
96#[derive(Clone)]
98pub struct LayerState {
99 pub att_x_prev: Tensor1D,
101 pub att_state: Tensor1D, pub ffn_x_prev: Tensor1D,
105}
106
107impl LayerState {
108 fn new(cfg: &Config) -> Self {
109 let state_size = cfg.num_heads * cfg.head_dim * cfg.head_dim;
110 Self {
111 att_x_prev: Tensor1D::zeros(cfg.hidden_size),
112 att_state: Tensor1D::zeros(state_size),
113 ffn_x_prev: Tensor1D::zeros(cfg.hidden_size),
114 }
115 }
116}
117
/// Recurrent inference state for the whole model.
#[derive(Clone)]
pub struct State {
    /// One recurrent state per layer.
    pub layers: Vec<LayerState>,
    // Value buffer of length hidden_size. NOTE(review): presumably the first
    // layer's value vector reused by later layers — confirm in the forward pass.
    pub v_first: Tensor1D,
    // True once `v_first` has been written (cleared by `reset`).
    pub v_first_set: bool,
}
128
129impl State {
130 pub fn new(cfg: &Config) -> Self {
132 Self {
133 layers: (0..cfg.num_layers).map(|_| LayerState::new(cfg)).collect(),
134 v_first: Tensor1D::zeros(cfg.hidden_size),
135 v_first_set: false,
136 }
137 }
138
139 pub fn reset(&mut self) {
141 self.v_first_set = false;
142 self.v_first.zero();
143 for layer in &mut self.layers {
144 layer.att_x_prev.zero();
145 layer.att_state.zero();
146 layer.ffn_x_prev.zero();
147 }
148 }
149}
150
151#[derive(Clone)]
153struct AttentionWeights {
154 x_r: Tensor1D,
156 x_w: Tensor1D,
157 x_k: Tensor1D,
158 x_v: Tensor1D,
159 x_a: Tensor1D,
160 x_g: Tensor1D,
161
162 rkv_proj: Tensor1D,
165
166 o_proj: Tensor1D,
168
169 w1: Tensor1D, w2: Tensor1D, w0: Tensor1D, a1: Tensor1D, a2: Tensor1D, a0: Tensor1D, v1: Option<Tensor1D>, v2: Option<Tensor1D>, v0: Option<Tensor1D>, g1: Tensor1D, g2: Tensor1D, k_k: Tensor1D, k_a: Tensor1D, r_k: Tensor1D, g_norm_w: Tensor1D, g_norm_b: Tensor1D, }
197
198#[derive(Clone)]
200struct FfnWeights {
201 x_k: Tensor1D, key_w: Tensor1D, value_w: Tensor1D, }
205
/// All parameters for one RWKV7 block (one layer).
#[derive(Clone)]
struct BlockWeights {
    // Extra input norm, present only on layer 0 (`pre_norm.*` in checkpoints).
    pre_norm_w: Option<Tensor1D>,
    pre_norm_b: Option<Tensor1D>,

    // Attention-block layer-norm affine parameters (`attn_norm.*`).
    attn_norm_w: Tensor1D,
    attn_norm_b: Tensor1D,

    // FFN-block layer-norm affine parameters (`ffn_norm.*`).
    ffn_norm_w: Tensor1D,
    ffn_norm_b: Tensor1D,

    attn: AttentionWeights,
    ffn: FfnWeights,
}
224
/// RWKV7 model: configuration plus all parameter tensors.
#[derive(Clone)]
pub struct Model {
    // Shape configuration (derived from checkpoint shapes in `load`).
    cfg: Config,

    // Token embedding table, vocab_size x hidden_size.
    embeddings: Tensor1D,

    // Final norm parameters (`model.norm.*` in checkpoints).
    ln_out_w: Tensor1D,
    ln_out_b: Tensor1D,

    // LM head projection, vocab_size x hidden_size.
    lm_head: Tensor1D,

    // Per-layer parameters, index == layer index.
    blocks: Vec<BlockWeights>,
}
243
/// Adam moment estimates for one parameter tensor.
#[derive(Clone)]
struct AdamTensorState {
    // First-moment accumulator.
    m: Tensor1D,
    // Second-moment accumulator.
    v: Tensor1D,
}
249
250impl AdamTensorState {
251 #[inline]
252 fn new(len: usize) -> Self {
253 Self {
254 m: Tensor1D::zeros(len),
255 v: Tensor1D::zeros(len),
256 }
257 }
258}
259
/// Adam moment buffers for every attention parameter; fields mirror
/// `AttentionWeights` one-to-one (optional where the parameter is optional).
#[derive(Clone)]
struct AttentionAdamState {
    x_r: AdamTensorState,
    x_w: AdamTensorState,
    x_k: AdamTensorState,
    x_v: AdamTensorState,
    x_a: AdamTensorState,
    x_g: AdamTensorState,
    rkv_proj: AdamTensorState,
    o_proj: AdamTensorState,
    w1: AdamTensorState,
    w2: AdamTensorState,
    w0: AdamTensorState,
    a1: AdamTensorState,
    a2: AdamTensorState,
    a0: AdamTensorState,
    // v LoRA is absent on layer 0, so its moments are optional too.
    v1: Option<AdamTensorState>,
    v2: Option<AdamTensorState>,
    v0: Option<AdamTensorState>,
    g1: AdamTensorState,
    g2: AdamTensorState,
    k_k: AdamTensorState,
    k_a: AdamTensorState,
    r_k: AdamTensorState,
    g_norm_w: AdamTensorState,
    g_norm_b: AdamTensorState,
}
287
/// Adam moment buffers for the FFN parameters; mirrors `FfnWeights`.
#[derive(Clone)]
struct FfnAdamState {
    x_k: AdamTensorState,
    key_w: AdamTensorState,
    value_w: AdamTensorState,
}
294
/// Adam moment buffers for one block; mirrors `BlockWeights`.
#[derive(Clone)]
struct BlockAdamState {
    // Optional: only layer 0 has a pre-norm.
    pre_norm_w: Option<AdamTensorState>,
    pre_norm_b: Option<AdamTensorState>,
    attn_norm_w: AdamTensorState,
    attn_norm_b: AdamTensorState,
    ffn_norm_w: AdamTensorState,
    ffn_norm_b: AdamTensorState,
    attn: AttentionAdamState,
    ffn: FfnAdamState,
}
306
/// Adam optimizer state covering every model parameter; mirrors `Model`.
/// Built by `Model::new_full_adam_state`.
#[derive(Clone)]
pub struct FullAdamState {
    embeddings: AdamTensorState,
    ln_out_w: AdamTensorState,
    ln_out_b: AdamTensorState,
    lm_head: AdamTensorState,
    blocks: Vec<BlockAdamState>,
}
316
/// Flags selecting which parameter groups receive gradient updates.
/// `Default` yields an all-false mask (train nothing).
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct TrainScopeMask {
    /// Token embedding table.
    pub embed: bool,
    /// Layer-0 pre-norm parameters.
    pub pre_norm: bool,
    /// Per-block attention layer-norm parameters.
    pub attn_norm: bool,
    /// Per-block FFN layer-norm parameters.
    pub ffn_norm: bool,
    /// Attention-block parameters.
    pub attn: bool,
    /// FFN-block parameters.
    pub ffn: bool,
    /// LM head projection.
    pub head: bool,
    /// Bias parameters.
    pub bias: bool,
}

impl TrainScopeMask {
    /// Mask with every parameter group enabled.
    #[inline]
    pub fn all() -> Self {
        Self {
            embed: true,
            pre_norm: true,
            attn_norm: true,
            ffn_norm: true,
            attn: true,
            ffn: true,
            head: true,
            bias: true,
        }
    }

    /// True when any group other than the head/bias is trained.
    #[inline]
    pub fn trains_non_head_params(&self) -> bool {
        let Self {
            embed,
            pre_norm,
            attn_norm,
            ffn_norm,
            attn,
            ffn,
            ..
        } = *self;
        embed | pre_norm | attn_norm | ffn_norm | attn | ffn
    }

    /// True when at least one parameter group is trained.
    #[inline]
    pub fn trains_any_params(&self) -> bool {
        self.trains_non_head_params() | self.head | self.bias
    }
}
366
/// Gradient accumulators for the attention parameters; fields mirror
/// `AttentionWeights` one-to-one.
#[derive(Clone)]
struct AttentionGradState {
    x_r: Tensor1D,
    x_w: Tensor1D,
    x_k: Tensor1D,
    x_v: Tensor1D,
    x_a: Tensor1D,
    x_g: Tensor1D,
    rkv_proj: Tensor1D,
    o_proj: Tensor1D,
    w1: Tensor1D,
    w2: Tensor1D,
    w0: Tensor1D,
    a1: Tensor1D,
    a2: Tensor1D,
    a0: Tensor1D,
    // v LoRA is absent on layer 0, so its gradients are optional too.
    v1: Option<Tensor1D>,
    v2: Option<Tensor1D>,
    v0: Option<Tensor1D>,
    g1: Tensor1D,
    g2: Tensor1D,
    k_k: Tensor1D,
    k_a: Tensor1D,
    r_k: Tensor1D,
    g_norm_w: Tensor1D,
    g_norm_b: Tensor1D,
}
394
/// Gradient accumulators for the FFN parameters; mirrors `FfnWeights`.
#[derive(Clone)]
struct FfnGradState {
    x_k: Tensor1D,
    key_w: Tensor1D,
    value_w: Tensor1D,
}
401
/// Gradient accumulators for one block; mirrors `BlockWeights`.
#[derive(Clone)]
struct BlockGradState {
    // Optional: only layer 0 has a pre-norm.
    pre_norm_w: Option<Tensor1D>,
    pre_norm_b: Option<Tensor1D>,
    attn_norm_w: Tensor1D,
    attn_norm_b: Tensor1D,
    ffn_norm_w: Tensor1D,
    ffn_norm_b: Tensor1D,
    attn: AttentionGradState,
    ffn: FfnGradState,
}
413
/// Gradient accumulators covering every model parameter; mirrors `Model`.
#[derive(Clone)]
struct FullGradState {
    embeddings: Tensor1D,
    ln_out_w: Tensor1D,
    ln_out_b: Tensor1D,
    lm_head: Tensor1D,
    blocks: Vec<BlockGradState>,
}
422
/// Per-step Adam update constants, precomputed once per optimizer step.
struct AdamStep {
    // Learning rate.
    lr: f32,
    // Clipping threshold — presumably applied to gradients; confirm at the
    // update site.
    clip: f32,
    // Adam beta1 (first-moment decay).
    b1: f32,
    // Adam beta2 (second-moment decay).
    b2: f32,
    // Numerical-stability epsilon in the denominator.
    eps: f32,
    // Bias-correction factors for the two moments (computed by the caller).
    bias_corr1: f32,
    bias_corr2: f32,
}
432
/// Forward-pass intermediates for one layer at one token, recorded so the
/// backward pass can compute gradients without re-running the forward.
/// Buffer lengths are fixed by `Config` (see `LayerTrainTrace::new`).
#[derive(Clone)]
struct LayerTrainTrace {
    // Layer input and pre/attention norms.
    x_in: Tensor1D,
    x_after_pre: Tensor1D,
    attn_norm: Tensor1D,
    // Recurrent state as it was *before* this token.
    att_x_prev_old: Tensor1D,
    ffn_x_prev_old: Tensor1D,
    att_state_old: Tensor1D,
    // Mixed inputs to the r/w/k/v/a/g projection paths.
    xr: Tensor1D,
    xw: Tensor1D,
    xk: Tensor1D,
    xv: Tensor1D,
    xa: Tensor1D,
    xg: Tensor1D,
    // Projection outputs (pre/post elementwise transforms).
    r: Tensor1D,
    k_pre: Tensor1D,
    k: Tensor1D,
    v_pre: Tensor1D,
    v: Tensor1D,
    nu: Tensor1D,
    // Decay (`w`) path: LoRA hidden (decay_low_rank), then per-channel stages.
    w_hidden: Tensor1D,
    w_pre: Tensor1D,
    w_sigmoid: Tensor1D,
    w_decay: Tensor1D,
    // `a` path: LoRA hidden (a_low_rank) and output.
    a_hidden: Tensor1D,
    a: Tensor1D,
    // Gate path: LoRA hidden (g_low_rank) and output.
    g_hidden: Tensor1D,
    g: Tensor1D,
    // k_k-scaled key, before/after normalization.
    kk_pre: Tensor1D,
    kk: Tensor1D,
    // Attention output stages: WKV output, group-norm output, per-head
    // alpha (num_heads), head output, gated output, final projection.
    y_wkv: Tensor1D,
    y_gn: Tensor1D,
    alpha: Tensor1D,
    y_head: Tensor1D,
    y_gate: Tensor1D,
    att_out: Tensor1D,
    x_after_attn: Tensor1D,
    // FFN stages: norm, mixed input, pre-activation and activated key
    // (intermediate_size), output, and the layer's final residual output.
    ffn_norm: Tensor1D,
    ffn_xk: Tensor1D,
    ffn_pre: Tensor1D,
    ffn_k: Tensor1D,
    ffn_out: Tensor1D,
    x_out: Tensor1D,
    // v LoRA hidden (v_low_rank, min 1 so the buffer is never empty).
    v_hidden: Tensor1D,
    // Whether this layer applied the value residual this token.
    uses_v_residual: bool,
}
479
480impl LayerTrainTrace {
481 fn new(cfg: &Config) -> Self {
482 let c = cfg.hidden_size;
483 let i = cfg.intermediate_size;
484 let state = cfg.num_heads * cfg.head_dim * cfg.head_dim;
485 Self {
486 x_in: Tensor1D::zeros(c),
487 x_after_pre: Tensor1D::zeros(c),
488 attn_norm: Tensor1D::zeros(c),
489 att_x_prev_old: Tensor1D::zeros(c),
490 ffn_x_prev_old: Tensor1D::zeros(c),
491 att_state_old: Tensor1D::zeros(state),
492 xr: Tensor1D::zeros(c),
493 xw: Tensor1D::zeros(c),
494 xk: Tensor1D::zeros(c),
495 xv: Tensor1D::zeros(c),
496 xa: Tensor1D::zeros(c),
497 xg: Tensor1D::zeros(c),
498 r: Tensor1D::zeros(c),
499 k_pre: Tensor1D::zeros(c),
500 k: Tensor1D::zeros(c),
501 v_pre: Tensor1D::zeros(c),
502 v: Tensor1D::zeros(c),
503 nu: Tensor1D::zeros(c),
504 w_hidden: Tensor1D::zeros(cfg.decay_low_rank),
505 w_pre: Tensor1D::zeros(c),
506 w_sigmoid: Tensor1D::zeros(c),
507 w_decay: Tensor1D::zeros(c),
508 a_hidden: Tensor1D::zeros(cfg.a_low_rank),
509 a: Tensor1D::zeros(c),
510 g_hidden: Tensor1D::zeros(cfg.g_low_rank),
511 g: Tensor1D::zeros(c),
512 kk_pre: Tensor1D::zeros(c),
513 kk: Tensor1D::zeros(c),
514 y_wkv: Tensor1D::zeros(c),
515 y_gn: Tensor1D::zeros(c),
516 alpha: Tensor1D::zeros(cfg.num_heads),
517 y_head: Tensor1D::zeros(c),
518 y_gate: Tensor1D::zeros(c),
519 att_out: Tensor1D::zeros(c),
520 x_after_attn: Tensor1D::zeros(c),
521 ffn_norm: Tensor1D::zeros(c),
522 ffn_xk: Tensor1D::zeros(c),
523 ffn_pre: Tensor1D::zeros(i),
524 ffn_k: Tensor1D::zeros(i),
525 ffn_out: Tensor1D::zeros(c),
526 x_out: Tensor1D::zeros(c),
527 v_hidden: Tensor1D::zeros(cfg.v_low_rank.max(1)),
528 uses_v_residual: false,
529 }
530 }
531}
532
/// Snapshot of model-level buffers for one token during training, captured
/// from `ScratchBuffers` by `from_scratch`.
#[derive(Clone)]
struct TokenTrainTrace {
    // Input token id for this step.
    token: usize,
    // Residual-stream activation (copy of scratch `x`).
    x: Tensor1D,
    // LM-head input (copy of scratch `x_normed`).
    x_normed: Tensor1D,
    // `v_first` recorded for this token (copy of scratch `train_v_first`).
    v_first: Tensor1D,
    // Per-layer forward traces.
    layers: Vec<LayerTrainTrace>,
}
541
542impl TokenTrainTrace {
543 fn from_scratch(scratch: &ScratchBuffers) -> Self {
544 Self {
545 token: scratch.train_token,
546 x: scratch.x.clone(),
547 x_normed: scratch.x_normed.clone(),
548 v_first: scratch.train_v_first.clone(),
549 layers: scratch.train_trace_layers.clone(),
550 }
551 }
552}
553
/// Gradients flowing backward through one layer's recurrent state;
/// mirrors `LayerState`.
#[derive(Clone)]
struct LayerRecurrentGradState {
    att_x_prev: Tensor1D,
    att_state: Tensor1D,
    ffn_x_prev: Tensor1D,
}
560
561impl LayerRecurrentGradState {
562 fn new(cfg: &Config) -> Self {
563 let state_size = cfg.num_heads * cfg.head_dim * cfg.head_dim;
564 Self {
565 att_x_prev: Tensor1D::zeros(cfg.hidden_size),
566 att_state: Tensor1D::zeros(state_size),
567 ffn_x_prev: Tensor1D::zeros(cfg.hidden_size),
568 }
569 }
570}
571
/// Recurrent-state gradients for every layer (backward-pass counterpart
/// of `State::layers`).
#[derive(Clone)]
struct RecurrentGradState {
    layers: Vec<LayerRecurrentGradState>,
}
576
577impl RecurrentGradState {
578 fn new(cfg: &Config) -> Self {
579 Self {
580 layers: (0..cfg.num_layers)
581 .map(|_| LayerRecurrentGradState::new(cfg))
582 .collect(),
583 }
584 }
585
586 fn zero(&mut self) {
587 for layer in &mut self.layers {
588 layer.att_x_prev.zero();
589 layer.att_state.zero();
590 layer.ffn_x_prev.zero();
591 }
592 }
593}
594
595#[derive(Clone)]
597pub struct ScratchBuffers {
598 x: Tensor1D, x_normed: Tensor1D, xr: Tensor1D, xw: Tensor1D, xk: Tensor1D, xv: Tensor1D, xa: Tensor1D, xg: Tensor1D, r: Tensor1D, k: Tensor1D, v: Tensor1D, w_lora_tmp: Tensor1D, w_decay: Tensor1D, a: Tensor1D, g: Tensor1D, kk: Tensor1D, y: Tensor1D, att_out: Tensor1D, ffn_k: Tensor1D, ffn_out: Tensor1D, logits: Tensor1D, grad_x: Tensor1D,
620 grad_x2: Tensor1D,
621 grad_x3: Tensor1D,
622 grad_x4: Tensor1D,
623 grad_x5: Tensor1D,
624 grad_x6: Tensor1D,
625 grad_v_first: Tensor1D,
626 grad_param: Tensor1D,
627 grad_param2: Tensor1D,
628 grad_saved: Tensor1D,
629 grad_ffn: Tensor1D,
630 grad_ffn2: Tensor1D,
631 grad_low_rank: Tensor1D,
632 grad_low_rank2: Tensor1D,
633 grad_att_state: Tensor1D,
634 grad_logits: Tensor1D,
635 train_trace_layers: Vec<LayerTrainTrace>,
636 train_token: usize,
637 train_v_first: Tensor1D,
638 train_trace_valid: bool,
639 capture_train_trace: bool,
640}
641
642impl ScratchBuffers {
643 pub fn new(cfg: &Config) -> Self {
645 let c = cfg.hidden_size;
646 let i = cfg.intermediate_size;
647 let v = cfg.vocab_size;
648 let state_size = cfg.num_heads * cfg.head_dim * cfg.head_dim;
649 let d_rank = cfg
650 .decay_low_rank
651 .max(cfg.a_low_rank)
652 .max(cfg.v_low_rank)
653 .max(cfg.g_low_rank)
654 .max(64);
655 let mut train_trace_layers = Vec::with_capacity(cfg.num_layers);
656 for _ in 0..cfg.num_layers {
657 train_trace_layers.push(LayerTrainTrace::new(cfg));
658 }
659
660 Self {
661 x: Tensor1D::zeros(c),
662 x_normed: Tensor1D::zeros(c),
663 xr: Tensor1D::zeros(c),
664 xw: Tensor1D::zeros(c),
665 xk: Tensor1D::zeros(c),
666 xv: Tensor1D::zeros(c),
667 xa: Tensor1D::zeros(c),
668 xg: Tensor1D::zeros(c),
669 r: Tensor1D::zeros(c),
670 k: Tensor1D::zeros(c),
671 v: Tensor1D::zeros(c),
672 w_lora_tmp: Tensor1D::zeros(d_rank),
673 w_decay: Tensor1D::zeros(c),
674 a: Tensor1D::zeros(c),
675 g: Tensor1D::zeros(c),
676 kk: Tensor1D::zeros(c),
677 y: Tensor1D::zeros(c),
678 att_out: Tensor1D::zeros(c),
679 ffn_k: Tensor1D::zeros(i),
680 ffn_out: Tensor1D::zeros(c),
681 logits: Tensor1D::zeros(v),
682 grad_x: Tensor1D::zeros(c),
683 grad_x2: Tensor1D::zeros(c),
684 grad_x3: Tensor1D::zeros(c),
685 grad_x4: Tensor1D::zeros(c),
686 grad_x5: Tensor1D::zeros(c),
687 grad_x6: Tensor1D::zeros(c),
688 grad_v_first: Tensor1D::zeros(c),
689 grad_param: Tensor1D::zeros(c),
690 grad_param2: Tensor1D::zeros(c),
691 grad_saved: Tensor1D::zeros(c),
692 grad_ffn: Tensor1D::zeros(i),
693 grad_ffn2: Tensor1D::zeros(i),
694 grad_low_rank: Tensor1D::zeros(d_rank),
695 grad_low_rank2: Tensor1D::zeros(d_rank),
696 grad_att_state: Tensor1D::zeros(state_size),
697 grad_logits: Tensor1D::zeros(v),
698 train_trace_layers,
699 train_token: 0,
700 train_v_first: Tensor1D::zeros(c),
701 train_trace_valid: false,
702 capture_train_trace: false,
703 }
704 }
705
706 #[inline]
708 pub fn lm_head_input(&self) -> &[f32] {
709 self.x_normed.as_slice()
710 }
711
712 #[inline]
713 pub fn logits(&self) -> &[f32] {
715 self.logits.as_slice()
716 }
717
718 #[inline]
720 pub fn set_lm_head_input(&mut self, value: &[f32]) {
721 self.x_normed.as_mut_slice().copy_from_slice(value);
722 }
723
724 #[inline]
726 pub fn set_capture_train_trace(&mut self, enabled: bool) {
727 self.capture_train_trace = enabled;
728 if !enabled {
729 self.train_trace_valid = false;
730 }
731 }
732
733 #[inline]
735 pub fn has_train_trace(&self) -> bool {
736 self.train_trace_valid
737 }
738}
739
740impl Model {
741 fn tensor_from(weights: &Weights, name: &str) -> Result<Tensor1D> {
742 Ok(Tensor1D::from_vec(weights.require(name)?.data().to_vec()))
743 }
744
745 fn optional_tensor_from(weights: &Weights, name: &str) -> Option<Tensor1D> {
746 weights
747 .get(name)
748 .map(|tensor| Tensor1D::from_vec(tensor.data().to_vec()))
749 }
750
    /// Load a model from a safetensors checkpoint at `path`.
    ///
    /// The `Config` is derived from tensor shapes in the file: vocab/hidden
    /// sizes from the embedding table, layer count by probing
    /// `model.layers.N`, and LoRA ranks from the layer-0 (layer-1 for `v`)
    /// LoRA down-projections.
    ///
    /// # Errors
    /// Fails if the file cannot be read or a required tensor is missing.
    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
        let weights = Weights::load(path.as_ref()).context("Failed to load model weights")?;

        // Vocab and hidden sizes come from the embedding matrix shape.
        let emb = weights.require("model.embeddings.weight")?;
        let vocab_size = emb.shape()[0];
        let hidden_size = emb.shape()[1];

        // head_dim is fixed at 64 by the current kernels (see Config::validate).
        let num_heads = hidden_size / 64; let head_dim = 64;

        // Count layers by probing for per-layer attention projections.
        let mut num_layers = 0;
        while weights
            .get(&format!("model.layers.{}.attn.r_proj.weight", num_layers))
            .is_some()
        {
            num_layers += 1;
        }

        let ffn_key = weights.require("model.layers.0.ffn.key.weight")?;
        let intermediate_size = ffn_key.shape()[0];

        // Each LoRA rank is the first dimension of its down-projection.
        let w1 = weights.require("model.layers.0.attn.w_lora.lora.0.weight")?;
        let decay_low_rank = w1.shape()[0];

        let a1 = weights.require("model.layers.0.attn.a_lora.lora.0.weight")?;
        let a_low_rank = a1.shape()[0];

        let g1 = weights.require("model.layers.0.attn.g_lora.lora.0.weight")?;
        let g_low_rank = g1.shape()[0];

        // Layer 0 has no v LoRA, so probe layer 1; default to 32 when absent.
        let v_low_rank = if num_layers > 1 {
            if let Some(v1) = weights.get("model.layers.1.attn.v_lora.lora.0.weight") {
                v1.shape()[0]
            } else {
                32
            }
        } else {
            32
        };

        let cfg = Config {
            vocab_size,
            hidden_size,
            num_layers,
            num_heads,
            head_dim,
            intermediate_size,
            layer_norm_eps: 1e-5,
            group_norm_eps: 64e-5,
            decay_low_rank,
            a_low_rank,
            v_low_rank,
            g_low_rank,
        };

        let embeddings = Self::tensor_from(&weights, "model.embeddings.weight")?;

        let ln_out_w = Self::tensor_from(&weights, "model.norm.weight")?;
        let ln_out_b = Self::tensor_from(&weights, "model.norm.bias")?;

        let lm_head = Self::tensor_from(&weights, "lm_head.weight")?;

        let mut blocks = Vec::with_capacity(num_layers);
        for i in 0..num_layers {
            let prefix = format!("model.layers.{}", i);

            // Only layer 0 carries a pre-norm.
            let (pre_norm_w, pre_norm_b) = if i == 0 {
                (
                    Some(Self::tensor_from(
                        &weights,
                        &format!("{}.pre_norm.weight", prefix),
                    )?),
                    Some(Self::tensor_from(
                        &weights,
                        &format!("{}.pre_norm.bias", prefix),
                    )?),
                )
            } else {
                (None, None)
            };

            let attn_norm_w = Self::tensor_from(&weights, &format!("{}.attn_norm.weight", prefix))?;
            let attn_norm_b = Self::tensor_from(&weights, &format!("{}.attn_norm.bias", prefix))?;
            let ffn_norm_w = Self::tensor_from(&weights, &format!("{}.ffn_norm.weight", prefix))?;
            let ffn_norm_b = Self::tensor_from(&weights, &format!("{}.ffn_norm.bias", prefix))?;

            let r_proj_data = weights
                .require(&format!("{}.attn.r_proj.weight", prefix))?
                .data();
            let k_proj_data = weights
                .require(&format!("{}.attn.k_proj.weight", prefix))?
                .data();
            let v_proj_data = weights
                .require(&format!("{}.attn.v_proj.weight", prefix))?
                .data();

            // Fuse the separate r/k/v projections into one contiguous buffer
            // (r first, then k, then v) — save_safetensors splits them back.
            let proj_size = hidden_size * hidden_size;
            let mut rkv_proj = Tensor1D::zeros(3 * proj_size);
            rkv_proj.as_mut_slice()[0..proj_size].copy_from_slice(r_proj_data);
            rkv_proj.as_mut_slice()[proj_size..2 * proj_size].copy_from_slice(k_proj_data);
            rkv_proj.as_mut_slice()[2 * proj_size..3 * proj_size].copy_from_slice(v_proj_data);

            let attn = AttentionWeights {
                x_r: Self::tensor_from(&weights, &format!("{}.attn.x_r", prefix))?,
                x_w: Self::tensor_from(&weights, &format!("{}.attn.x_w", prefix))?,
                x_k: Self::tensor_from(&weights, &format!("{}.attn.x_k", prefix))?,
                x_v: Self::tensor_from(&weights, &format!("{}.attn.x_v", prefix))?,
                x_a: Self::tensor_from(&weights, &format!("{}.attn.x_a", prefix))?,
                x_g: Self::tensor_from(&weights, &format!("{}.attn.x_g", prefix))?,

                rkv_proj,
                o_proj: Self::tensor_from(&weights, &format!("{}.attn.o_proj.weight", prefix))?,

                w1: Self::tensor_from(&weights, &format!("{}.attn.w_lora.lora.0.weight", prefix))?,
                w2: Self::tensor_from(&weights, &format!("{}.attn.w_lora.lora.2.weight", prefix))?,
                w0: Self::tensor_from(&weights, &format!("{}.attn.w_lora.lora.2.bias", prefix))?,

                a1: Self::tensor_from(&weights, &format!("{}.attn.a_lora.lora.0.weight", prefix))?,
                a2: Self::tensor_from(&weights, &format!("{}.attn.a_lora.lora.2.weight", prefix))?,
                a0: Self::tensor_from(&weights, &format!("{}.attn.a_lora.lora.2.bias", prefix))?,

                // Optional: absent on layer 0.
                v1: Self::optional_tensor_from(
                    &weights,
                    &format!("{}.attn.v_lora.lora.0.weight", prefix),
                ),
                v2: Self::optional_tensor_from(
                    &weights,
                    &format!("{}.attn.v_lora.lora.2.weight", prefix),
                ),
                v0: Self::optional_tensor_from(
                    &weights,
                    &format!("{}.attn.v_lora.lora.2.bias", prefix),
                ),

                g1: Self::tensor_from(&weights, &format!("{}.attn.g_lora.lora.0.weight", prefix))?,
                g2: Self::tensor_from(&weights, &format!("{}.attn.g_lora.lora.2.weight", prefix))?,

                k_k: Self::tensor_from(&weights, &format!("{}.attn.k_k", prefix))?,
                k_a: Self::tensor_from(&weights, &format!("{}.attn.k_a", prefix))?,
                r_k: Self::tensor_from(&weights, &format!("{}.attn.r_k", prefix))?,

                g_norm_w: Self::tensor_from(&weights, &format!("{}.attn.g_norm.weight", prefix))?,
                g_norm_b: Self::tensor_from(&weights, &format!("{}.attn.g_norm.bias", prefix))?,
            };

            let ffn = FfnWeights {
                x_k: Self::tensor_from(&weights, &format!("{}.ffn.x_k", prefix))?,
                key_w: Self::tensor_from(&weights, &format!("{}.ffn.key.weight", prefix))?,
                value_w: Self::tensor_from(&weights, &format!("{}.ffn.value.weight", prefix))?,
            };

            blocks.push(BlockWeights {
                pre_norm_w,
                pre_norm_b,
                attn_norm_w,
                attn_norm_b,
                ffn_norm_w,
                ffn_norm_b,
                attn,
                ffn,
            });
        }

        Ok(Self {
            cfg,
            embeddings,
            ln_out_w,
            ln_out_b,
            lm_head,
            blocks,
        })
    }
939
    /// Build a freshly initialized model for `cfg`, seeded by `seed`.
    ///
    /// Weight matrices and embeddings get small uniform init; norm weights
    /// start at 1 with zero bias; LoRA biases start at 0; mixing vectors are
    /// centered around 0.5. NOTE: the order of `rng` draws below defines the
    /// initialization — do not reorder the init calls.
    ///
    /// # Errors
    /// Fails if `cfg` does not pass `Config::validate`.
    pub fn new_random(cfg: Config, seed: u64) -> Result<Self> {
        cfg.validate()?;

        let mut rng = RwkvRng::new(seed);
        let c = cfg.hidden_size;
        let v = cfg.vocab_size;
        let i = cfg.intermediate_size;
        let d_w = cfg.decay_low_rank;
        let d_a = cfg.a_low_rank;
        let d_v = cfg.v_low_rank;
        let d_g = cfg.g_low_rank;

        let mut embeddings = Tensor1D::zeros(v * c);
        init_uniform(&mut embeddings, &mut rng, 0.02);

        // Final norm: identity affine (weight 1, bias 0).
        let mut ln_out_w = Tensor1D::zeros(c);
        let mut ln_out_b = Tensor1D::zeros(c);
        init_const(&mut ln_out_w, 1.0);
        init_const(&mut ln_out_b, 0.0);

        let mut lm_head = Tensor1D::zeros(v * c);
        init_uniform(&mut lm_head, &mut rng, 0.02);

        let mut blocks = Vec::with_capacity(cfg.num_layers);
        for layer_idx in 0..cfg.num_layers {
            // Only layer 0 carries a pre-norm.
            let (pre_norm_w, pre_norm_b) = if layer_idx == 0 {
                let mut w = Tensor1D::zeros(c);
                let mut b = Tensor1D::zeros(c);
                init_const(&mut w, 1.0);
                init_const(&mut b, 0.0);
                (Some(w), Some(b))
            } else {
                (None, None)
            };

            let mut attn_norm_w = Tensor1D::zeros(c);
            let mut attn_norm_b = Tensor1D::zeros(c);
            init_const(&mut attn_norm_w, 1.0);
            init_const(&mut attn_norm_b, 0.0);

            let mut ffn_norm_w = Tensor1D::zeros(c);
            let mut ffn_norm_b = Tensor1D::zeros(c);
            init_const(&mut ffn_norm_w, 1.0);
            init_const(&mut ffn_norm_b, 0.0);

            // Fused r/k/v projection buffer (see Model::load for the layout).
            let mut rkv_proj = Tensor1D::zeros(3 * c * c);
            init_uniform(&mut rkv_proj, &mut rng, 0.02);

            let mut o_proj = Tensor1D::zeros(c * c);
            init_uniform(&mut o_proj, &mut rng, 0.02);

            // Decay LoRA: random factors, zero bias.
            let mut w1 = Tensor1D::zeros(d_w * c);
            let mut w2 = Tensor1D::zeros(c * d_w);
            let mut w0 = Tensor1D::zeros(c);
            init_uniform(&mut w1, &mut rng, 0.02);
            init_uniform(&mut w2, &mut rng, 0.02);
            init_const(&mut w0, 0.0);

            // `a` LoRA: random factors, zero bias.
            let mut a1 = Tensor1D::zeros(d_a * c);
            let mut a2 = Tensor1D::zeros(c * d_a);
            let mut a0 = Tensor1D::zeros(c);
            init_uniform(&mut a1, &mut rng, 0.02);
            init_uniform(&mut a2, &mut rng, 0.02);
            init_const(&mut a0, 0.0);

            // v LoRA exists only on layers > 0.
            let (v1, v2, v0) = if layer_idx == 0 {
                (None, None, None)
            } else {
                let mut v1 = Tensor1D::zeros(d_v * c);
                let mut v2 = Tensor1D::zeros(c * d_v);
                let mut v0 = Tensor1D::zeros(c);
                init_uniform(&mut v1, &mut rng, 0.02);
                init_uniform(&mut v2, &mut rng, 0.02);
                init_const(&mut v0, 0.0);
                (Some(v1), Some(v2), Some(v0))
            };

            // Gate LoRA (no bias term).
            let mut g1 = Tensor1D::zeros(d_g * c);
            let mut g2 = Tensor1D::zeros(c * d_g);
            init_uniform(&mut g1, &mut rng, 0.02);
            init_uniform(&mut g2, &mut rng, 0.02);

            // Mixing vectors: centered around 0.5 with small jitter.
            let mut x_r = Tensor1D::zeros(c);
            let mut x_w = Tensor1D::zeros(c);
            let mut x_k = Tensor1D::zeros(c);
            let mut x_v = Tensor1D::zeros(c);
            let mut x_a = Tensor1D::zeros(c);
            let mut x_g = Tensor1D::zeros(c);
            init_centered(&mut x_r, &mut rng, 0.5, 0.02);
            init_centered(&mut x_w, &mut rng, 0.5, 0.02);
            init_centered(&mut x_k, &mut rng, 0.5, 0.02);
            init_centered(&mut x_v, &mut rng, 0.5, 0.02);
            init_centered(&mut x_a, &mut rng, 0.5, 0.02);
            init_centered(&mut x_g, &mut rng, 0.5, 0.02);

            // Per-channel scaling vectors start at 1 (pass-through).
            let mut k_k = Tensor1D::zeros(c);
            let mut k_a = Tensor1D::zeros(c);
            let mut r_k = Tensor1D::zeros(c);
            init_const(&mut k_k, 1.0);
            init_const(&mut k_a, 1.0);
            init_const(&mut r_k, 1.0);

            let mut g_norm_w = Tensor1D::zeros(c);
            let mut g_norm_b = Tensor1D::zeros(c);
            init_const(&mut g_norm_w, 1.0);
            init_const(&mut g_norm_b, 0.0);

            let attn = AttentionWeights {
                x_r,
                x_w,
                x_k,
                x_v,
                x_a,
                x_g,
                rkv_proj,
                o_proj,
                w1,
                w2,
                w0,
                a1,
                a2,
                a0,
                v1,
                v2,
                v0,
                g1,
                g2,
                k_k,
                k_a,
                r_k,
                g_norm_w,
                g_norm_b,
            };

            let mut ffn_x_k = Tensor1D::zeros(c);
            init_centered(&mut ffn_x_k, &mut rng, 0.5, 0.02);
            let mut key_w = Tensor1D::zeros(i * c);
            let mut value_w = Tensor1D::zeros(c * i);
            init_uniform(&mut key_w, &mut rng, 0.02);
            init_uniform(&mut value_w, &mut rng, 0.02);

            let ffn = FfnWeights {
                x_k: ffn_x_k,
                key_w,
                value_w,
            };

            blocks.push(BlockWeights {
                pre_norm_w,
                pre_norm_b,
                attn_norm_w,
                attn_norm_b,
                ffn_norm_w,
                ffn_norm_b,
                attn,
                ffn,
            });
        }

        Ok(Self {
            cfg,
            embeddings,
            ln_out_w,
            ln_out_b,
            lm_head,
            blocks,
        })
    }
1109
1110 pub fn save_safetensors<P: AsRef<Path>>(&self, path: P) -> Result<()> {
1112 #[derive(Clone)]
1113 struct TensorRec {
1114 name: String,
1115 shape: Vec<usize>,
1116 data: Vec<f32>,
1117 }
1118
1119 let c = self.cfg.hidden_size;
1120 let v = self.cfg.vocab_size;
1121 let i = self.cfg.intermediate_size;
1122 let d_w = self.cfg.decay_low_rank;
1123 let d_a = self.cfg.a_low_rank;
1124 let d_v = self.cfg.v_low_rank;
1125 let d_g = self.cfg.g_low_rank;
1126
1127 let mut recs = Vec::<TensorRec>::new();
1128 let push = |recs: &mut Vec<TensorRec>, name: String, shape: Vec<usize>, src: &Tensor1D| {
1129 recs.push(TensorRec {
1130 name,
1131 shape,
1132 data: src.as_slice().to_vec(),
1133 });
1134 };
1135
1136 push(
1137 &mut recs,
1138 "model.embeddings.weight".to_string(),
1139 vec![v, c],
1140 &self.embeddings,
1141 );
1142 push(
1143 &mut recs,
1144 "model.norm.weight".to_string(),
1145 vec![c],
1146 &self.ln_out_w,
1147 );
1148 push(
1149 &mut recs,
1150 "model.norm.bias".to_string(),
1151 vec![c],
1152 &self.ln_out_b,
1153 );
1154 push(
1155 &mut recs,
1156 "lm_head.weight".to_string(),
1157 vec![v, c],
1158 &self.lm_head,
1159 );
1160
1161 for (idx, b) in self.blocks.iter().enumerate() {
1162 let pfx = format!("model.layers.{idx}");
1163 if let (Some(w), Some(bias)) = (&b.pre_norm_w, &b.pre_norm_b) {
1164 push(&mut recs, format!("{pfx}.pre_norm.weight"), vec![c], w);
1165 push(&mut recs, format!("{pfx}.pre_norm.bias"), vec![c], bias);
1166 }
1167
1168 push(
1169 &mut recs,
1170 format!("{pfx}.attn_norm.weight"),
1171 vec![c],
1172 &b.attn_norm_w,
1173 );
1174 push(
1175 &mut recs,
1176 format!("{pfx}.attn_norm.bias"),
1177 vec![c],
1178 &b.attn_norm_b,
1179 );
1180 push(
1181 &mut recs,
1182 format!("{pfx}.ffn_norm.weight"),
1183 vec![c],
1184 &b.ffn_norm_w,
1185 );
1186 push(
1187 &mut recs,
1188 format!("{pfx}.ffn_norm.bias"),
1189 vec![c],
1190 &b.ffn_norm_b,
1191 );
1192
1193 let proj = b.attn.rkv_proj.as_slice();
1194 let proj_size = c * c;
1195 recs.push(TensorRec {
1196 name: format!("{pfx}.attn.r_proj.weight"),
1197 shape: vec![c, c],
1198 data: proj[0..proj_size].to_vec(),
1199 });
1200 recs.push(TensorRec {
1201 name: format!("{pfx}.attn.k_proj.weight"),
1202 shape: vec![c, c],
1203 data: proj[proj_size..2 * proj_size].to_vec(),
1204 });
1205 recs.push(TensorRec {
1206 name: format!("{pfx}.attn.v_proj.weight"),
1207 shape: vec![c, c],
1208 data: proj[2 * proj_size..3 * proj_size].to_vec(),
1209 });
1210
1211 push(
1212 &mut recs,
1213 format!("{pfx}.attn.o_proj.weight"),
1214 vec![c, c],
1215 &b.attn.o_proj,
1216 );
1217 push(&mut recs, format!("{pfx}.attn.x_r"), vec![c], &b.attn.x_r);
1218 push(&mut recs, format!("{pfx}.attn.x_w"), vec![c], &b.attn.x_w);
1219 push(&mut recs, format!("{pfx}.attn.x_k"), vec![c], &b.attn.x_k);
1220 push(&mut recs, format!("{pfx}.attn.x_v"), vec![c], &b.attn.x_v);
1221 push(&mut recs, format!("{pfx}.attn.x_a"), vec![c], &b.attn.x_a);
1222 push(&mut recs, format!("{pfx}.attn.x_g"), vec![c], &b.attn.x_g);
1223
1224 push(
1225 &mut recs,
1226 format!("{pfx}.attn.w_lora.lora.0.weight"),
1227 vec![d_w, c],
1228 &b.attn.w1,
1229 );
1230 push(
1231 &mut recs,
1232 format!("{pfx}.attn.w_lora.lora.2.weight"),
1233 vec![c, d_w],
1234 &b.attn.w2,
1235 );
1236 push(
1237 &mut recs,
1238 format!("{pfx}.attn.w_lora.lora.2.bias"),
1239 vec![c],
1240 &b.attn.w0,
1241 );
1242
1243 push(
1244 &mut recs,
1245 format!("{pfx}.attn.a_lora.lora.0.weight"),
1246 vec![d_a, c],
1247 &b.attn.a1,
1248 );
1249 push(
1250 &mut recs,
1251 format!("{pfx}.attn.a_lora.lora.2.weight"),
1252 vec![c, d_a],
1253 &b.attn.a2,
1254 );
1255 push(
1256 &mut recs,
1257 format!("{pfx}.attn.a_lora.lora.2.bias"),
1258 vec![c],
1259 &b.attn.a0,
1260 );
1261
1262 if let Some(v1) = &b.attn.v1 {
1263 push(
1264 &mut recs,
1265 format!("{pfx}.attn.v_lora.lora.0.weight"),
1266 vec![d_v, c],
1267 v1,
1268 );
1269 }
1270 if let Some(v2) = &b.attn.v2 {
1271 push(
1272 &mut recs,
1273 format!("{pfx}.attn.v_lora.lora.2.weight"),
1274 vec![c, d_v],
1275 v2,
1276 );
1277 }
1278 if let Some(v0) = &b.attn.v0 {
1279 push(
1280 &mut recs,
1281 format!("{pfx}.attn.v_lora.lora.2.bias"),
1282 vec![c],
1283 v0,
1284 );
1285 }
1286
1287 push(
1288 &mut recs,
1289 format!("{pfx}.attn.g_lora.lora.0.weight"),
1290 vec![d_g, c],
1291 &b.attn.g1,
1292 );
1293 push(
1294 &mut recs,
1295 format!("{pfx}.attn.g_lora.lora.2.weight"),
1296 vec![c, d_g],
1297 &b.attn.g2,
1298 );
1299
1300 push(&mut recs, format!("{pfx}.attn.k_k"), vec![c], &b.attn.k_k);
1301 push(&mut recs, format!("{pfx}.attn.k_a"), vec![c], &b.attn.k_a);
1302 push(&mut recs, format!("{pfx}.attn.r_k"), vec![c], &b.attn.r_k);
1303 push(
1304 &mut recs,
1305 format!("{pfx}.attn.g_norm.weight"),
1306 vec![c],
1307 &b.attn.g_norm_w,
1308 );
1309 push(
1310 &mut recs,
1311 format!("{pfx}.attn.g_norm.bias"),
1312 vec![c],
1313 &b.attn.g_norm_b,
1314 );
1315
1316 push(&mut recs, format!("{pfx}.ffn.x_k"), vec![c], &b.ffn.x_k);
1317 push(
1318 &mut recs,
1319 format!("{pfx}.ffn.key.weight"),
1320 vec![i, c],
1321 &b.ffn.key_w,
1322 );
1323 push(
1324 &mut recs,
1325 format!("{pfx}.ffn.value.weight"),
1326 vec![c, i],
1327 &b.ffn.value_w,
1328 );
1329 }
1330
1331 recs.sort_by(|a, b| a.name.cmp(&b.name));
1332 let mut offset = 0usize;
1333 let mut header = serde_json::Map::new();
1334 header.insert("__metadata__".to_string(), json!({}));
1335 for rec in &recs {
1336 let bytes = rec.data.len() * 4;
1337 header.insert(
1338 rec.name.clone(),
1339 json!({
1340 "dtype": "F32",
1341 "shape": rec.shape,
1342 "data_offsets": [offset, offset + bytes]
1343 }),
1344 );
1345 offset += bytes;
1346 }
1347 let header_bytes = serde_json::to_vec(&header)?;
1348 let mut f = File::create(path.as_ref())?;
1349 f.write_all(&(header_bytes.len() as u64).to_le_bytes())?;
1350 f.write_all(&header_bytes)?;
1351 for rec in &recs {
1352 for v in &rec.data {
1353 f.write_all(&v.to_le_bytes())?;
1354 }
1355 }
1356 Ok(())
1357 }
1358
    /// Allocate zeroed Adam moment buffers matching every parameter tensor
    /// of this model, one-to-one (optional parameters get optional states).
    pub fn new_full_adam_state(&self) -> FullAdamState {
        let mut blocks = Vec::with_capacity(self.blocks.len());
        for b in &self.blocks {
            blocks.push(BlockAdamState {
                // Present only for layer 0, mirroring the pre-norm weights.
                pre_norm_w: b.pre_norm_w.as_ref().map(|t| AdamTensorState::new(t.len())),
                pre_norm_b: b.pre_norm_b.as_ref().map(|t| AdamTensorState::new(t.len())),
                attn_norm_w: AdamTensorState::new(b.attn_norm_w.len()),
                attn_norm_b: AdamTensorState::new(b.attn_norm_b.len()),
                ffn_norm_w: AdamTensorState::new(b.ffn_norm_w.len()),
                ffn_norm_b: AdamTensorState::new(b.ffn_norm_b.len()),
                attn: AttentionAdamState {
                    x_r: AdamTensorState::new(b.attn.x_r.len()),
                    x_w: AdamTensorState::new(b.attn.x_w.len()),
                    x_k: AdamTensorState::new(b.attn.x_k.len()),
                    x_v: AdamTensorState::new(b.attn.x_v.len()),
                    x_a: AdamTensorState::new(b.attn.x_a.len()),
                    x_g: AdamTensorState::new(b.attn.x_g.len()),
                    rkv_proj: AdamTensorState::new(b.attn.rkv_proj.len()),
                    o_proj: AdamTensorState::new(b.attn.o_proj.len()),
                    w1: AdamTensorState::new(b.attn.w1.len()),
                    w2: AdamTensorState::new(b.attn.w2.len()),
                    w0: AdamTensorState::new(b.attn.w0.len()),
                    a1: AdamTensorState::new(b.attn.a1.len()),
                    a2: AdamTensorState::new(b.attn.a2.len()),
                    a0: AdamTensorState::new(b.attn.a0.len()),
                    // Optional: the v LoRA is absent on layer 0.
                    v1: b.attn.v1.as_ref().map(|t| AdamTensorState::new(t.len())),
                    v2: b.attn.v2.as_ref().map(|t| AdamTensorState::new(t.len())),
                    v0: b.attn.v0.as_ref().map(|t| AdamTensorState::new(t.len())),
                    g1: AdamTensorState::new(b.attn.g1.len()),
                    g2: AdamTensorState::new(b.attn.g2.len()),
                    k_k: AdamTensorState::new(b.attn.k_k.len()),
                    k_a: AdamTensorState::new(b.attn.k_a.len()),
                    r_k: AdamTensorState::new(b.attn.r_k.len()),
                    g_norm_w: AdamTensorState::new(b.attn.g_norm_w.len()),
                    g_norm_b: AdamTensorState::new(b.attn.g_norm_b.len()),
                },
                ffn: FfnAdamState {
                    x_k: AdamTensorState::new(b.ffn.x_k.len()),
                    key_w: AdamTensorState::new(b.ffn.key_w.len()),
                    value_w: AdamTensorState::new(b.ffn.value_w.len()),
                },
            });
        }
        FullAdamState {
            embeddings: AdamTensorState::new(self.embeddings.len()),
            ln_out_w: AdamTensorState::new(self.ln_out_w.len()),
            ln_out_b: AdamTensorState::new(self.ln_out_b.len()),
            lm_head: AdamTensorState::new(self.lm_head.len()),
            blocks,
        }
    }
1411
1412 fn new_full_grad_state(&self) -> FullGradState {
1414 let mut blocks = Vec::with_capacity(self.blocks.len());
1415 for b in &self.blocks {
1416 blocks.push(BlockGradState {
1417 pre_norm_w: b.pre_norm_w.as_ref().map(|t| Tensor1D::zeros(t.len())),
1418 pre_norm_b: b.pre_norm_b.as_ref().map(|t| Tensor1D::zeros(t.len())),
1419 attn_norm_w: Tensor1D::zeros(b.attn_norm_w.len()),
1420 attn_norm_b: Tensor1D::zeros(b.attn_norm_b.len()),
1421 ffn_norm_w: Tensor1D::zeros(b.ffn_norm_w.len()),
1422 ffn_norm_b: Tensor1D::zeros(b.ffn_norm_b.len()),
1423 attn: AttentionGradState {
1424 x_r: Tensor1D::zeros(b.attn.x_r.len()),
1425 x_w: Tensor1D::zeros(b.attn.x_w.len()),
1426 x_k: Tensor1D::zeros(b.attn.x_k.len()),
1427 x_v: Tensor1D::zeros(b.attn.x_v.len()),
1428 x_a: Tensor1D::zeros(b.attn.x_a.len()),
1429 x_g: Tensor1D::zeros(b.attn.x_g.len()),
1430 rkv_proj: Tensor1D::zeros(b.attn.rkv_proj.len()),
1431 o_proj: Tensor1D::zeros(b.attn.o_proj.len()),
1432 w1: Tensor1D::zeros(b.attn.w1.len()),
1433 w2: Tensor1D::zeros(b.attn.w2.len()),
1434 w0: Tensor1D::zeros(b.attn.w0.len()),
1435 a1: Tensor1D::zeros(b.attn.a1.len()),
1436 a2: Tensor1D::zeros(b.attn.a2.len()),
1437 a0: Tensor1D::zeros(b.attn.a0.len()),
1438 v1: b.attn.v1.as_ref().map(|t| Tensor1D::zeros(t.len())),
1439 v2: b.attn.v2.as_ref().map(|t| Tensor1D::zeros(t.len())),
1440 v0: b.attn.v0.as_ref().map(|t| Tensor1D::zeros(t.len())),
1441 g1: Tensor1D::zeros(b.attn.g1.len()),
1442 g2: Tensor1D::zeros(b.attn.g2.len()),
1443 k_k: Tensor1D::zeros(b.attn.k_k.len()),
1444 k_a: Tensor1D::zeros(b.attn.k_a.len()),
1445 r_k: Tensor1D::zeros(b.attn.r_k.len()),
1446 g_norm_w: Tensor1D::zeros(b.attn.g_norm_w.len()),
1447 g_norm_b: Tensor1D::zeros(b.attn.g_norm_b.len()),
1448 },
1449 ffn: FfnGradState {
1450 x_k: Tensor1D::zeros(b.ffn.x_k.len()),
1451 key_w: Tensor1D::zeros(b.ffn.key_w.len()),
1452 value_w: Tensor1D::zeros(b.ffn.value_w.len()),
1453 },
1454 });
1455 }
1456 FullGradState {
1457 embeddings: Tensor1D::zeros(self.embeddings.len()),
1458 ln_out_w: Tensor1D::zeros(self.ln_out_w.len()),
1459 ln_out_b: Tensor1D::zeros(self.ln_out_b.len()),
1460 lm_head: Tensor1D::zeros(self.lm_head.len()),
1461 blocks,
1462 }
1463 }
1464
1465 fn new_recurrent_grad_state(&self) -> RecurrentGradState {
1466 RecurrentGradState::new(&self.cfg)
1467 }
1468
1469 pub fn save_full_adam_safetensors<P: AsRef<Path>>(
1471 &self,
1472 adam: &FullAdamState,
1473 path: P,
1474 ) -> Result<()> {
1475 #[derive(Clone)]
1476 struct TensorRec {
1477 name: String,
1478 shape: Vec<usize>,
1479 data: Vec<f32>,
1480 }
1481 let c = self.cfg.hidden_size;
1482 let i = self.cfg.intermediate_size;
1483 let v = self.cfg.vocab_size;
1484 let h = self.cfg.num_heads;
1485 let n = self.cfg.head_dim;
1486 let d_w = self.cfg.decay_low_rank;
1487 let d_a = self.cfg.a_low_rank;
1488 let d_v = self.cfg.v_low_rank;
1489 let d_g = self.cfg.g_low_rank;
1490 let mut recs = Vec::<TensorRec>::new();
1491 let mut push_state = |name: &str, shape: Vec<usize>, st: &AdamTensorState| {
1492 recs.push(TensorRec {
1493 name: format!("{name}.m"),
1494 shape: shape.clone(),
1495 data: st.m.as_slice().to_vec(),
1496 });
1497 recs.push(TensorRec {
1498 name: format!("{name}.v"),
1499 shape,
1500 data: st.v.as_slice().to_vec(),
1501 });
1502 };
1503
1504 push_state("opt.model.embeddings.weight", vec![v, c], &adam.embeddings);
1505 push_state("opt.model.norm.weight", vec![c], &adam.ln_out_w);
1506 push_state("opt.model.norm.bias", vec![c], &adam.ln_out_b);
1507 push_state("opt.lm_head.weight", vec![v, c], &adam.lm_head);
1508 for (idx, b) in adam.blocks.iter().enumerate() {
1509 let p = format!("opt.model.layers.{idx}");
1510 if let Some(st) = &b.pre_norm_w {
1511 push_state(&format!("{p}.pre_norm.weight"), vec![c], st);
1512 }
1513 if let Some(st) = &b.pre_norm_b {
1514 push_state(&format!("{p}.pre_norm.bias"), vec![c], st);
1515 }
1516 push_state(&format!("{p}.attn_norm.weight"), vec![c], &b.attn_norm_w);
1517 push_state(&format!("{p}.attn_norm.bias"), vec![c], &b.attn_norm_b);
1518 push_state(&format!("{p}.ffn_norm.weight"), vec![c], &b.ffn_norm_w);
1519 push_state(&format!("{p}.ffn_norm.bias"), vec![c], &b.ffn_norm_b);
1520
1521 push_state(&format!("{p}.attn.x_r"), vec![c], &b.attn.x_r);
1522 push_state(&format!("{p}.attn.x_w"), vec![c], &b.attn.x_w);
1523 push_state(&format!("{p}.attn.x_k"), vec![c], &b.attn.x_k);
1524 push_state(&format!("{p}.attn.x_v"), vec![c], &b.attn.x_v);
1525 push_state(&format!("{p}.attn.x_a"), vec![c], &b.attn.x_a);
1526 push_state(&format!("{p}.attn.x_g"), vec![c], &b.attn.x_g);
1527 push_state(
1528 &format!("{p}.attn.rkv_proj"),
1529 vec![3, c, c],
1530 &b.attn.rkv_proj,
1531 );
1532 push_state(
1533 &format!("{p}.attn.o_proj.weight"),
1534 vec![c, c],
1535 &b.attn.o_proj,
1536 );
1537 push_state(
1538 &format!("{p}.attn.w_lora.lora.0.weight"),
1539 vec![d_w, c],
1540 &b.attn.w1,
1541 );
1542 push_state(
1543 &format!("{p}.attn.w_lora.lora.2.weight"),
1544 vec![c, d_w],
1545 &b.attn.w2,
1546 );
1547 push_state(&format!("{p}.attn.w_lora.lora.2.bias"), vec![c], &b.attn.w0);
1548 push_state(
1549 &format!("{p}.attn.a_lora.lora.0.weight"),
1550 vec![d_a, c],
1551 &b.attn.a1,
1552 );
1553 push_state(
1554 &format!("{p}.attn.a_lora.lora.2.weight"),
1555 vec![c, d_a],
1556 &b.attn.a2,
1557 );
1558 push_state(&format!("{p}.attn.a_lora.lora.2.bias"), vec![c], &b.attn.a0);
1559 if let Some(st) = &b.attn.v1 {
1560 push_state(&format!("{p}.attn.v_lora.lora.0.weight"), vec![d_v, c], st);
1561 }
1562 if let Some(st) = &b.attn.v2 {
1563 push_state(&format!("{p}.attn.v_lora.lora.2.weight"), vec![c, d_v], st);
1564 }
1565 if let Some(st) = &b.attn.v0 {
1566 push_state(&format!("{p}.attn.v_lora.lora.2.bias"), vec![c], st);
1567 }
1568 push_state(
1569 &format!("{p}.attn.g_lora.lora.0.weight"),
1570 vec![d_g, c],
1571 &b.attn.g1,
1572 );
1573 push_state(
1574 &format!("{p}.attn.g_lora.lora.2.weight"),
1575 vec![c, d_g],
1576 &b.attn.g2,
1577 );
1578 push_state(&format!("{p}.attn.k_k"), vec![c], &b.attn.k_k);
1579 push_state(&format!("{p}.attn.k_a"), vec![c], &b.attn.k_a);
1580 push_state(&format!("{p}.attn.r_k"), vec![h, n], &b.attn.r_k);
1581 push_state(
1582 &format!("{p}.attn.g_norm.weight"),
1583 vec![c],
1584 &b.attn.g_norm_w,
1585 );
1586 push_state(&format!("{p}.attn.g_norm.bias"), vec![c], &b.attn.g_norm_b);
1587
1588 push_state(&format!("{p}.ffn.x_k"), vec![c], &b.ffn.x_k);
1589 push_state(&format!("{p}.ffn.key.weight"), vec![i, c], &b.ffn.key_w);
1590 push_state(&format!("{p}.ffn.value.weight"), vec![c, i], &b.ffn.value_w);
1591 }
1592
1593 recs.sort_by(|a, b| a.name.cmp(&b.name));
1594 let mut offset = 0usize;
1595 let mut header = serde_json::Map::new();
1596 header.insert("__metadata__".to_string(), json!({}));
1597 for rec in &recs {
1598 let bytes = rec.data.len() * 4;
1599 header.insert(
1600 rec.name.clone(),
1601 json!({
1602 "dtype": "F32",
1603 "shape": rec.shape,
1604 "data_offsets": [offset, offset + bytes],
1605 }),
1606 );
1607 offset += bytes;
1608 }
1609
1610 let header_bytes = serde_json::to_vec(&header)?;
1611 let mut f = File::create(path)?;
1612 f.write_all(&(header_bytes.len() as u64).to_le_bytes())?;
1613 f.write_all(&header_bytes)?;
1614 for rec in &recs {
1615 for v in &rec.data {
1616 f.write_all(&v.to_le_bytes())?;
1617 }
1618 }
1619 Ok(())
1620 }
1621
1622 pub fn load_full_adam_safetensors<P: AsRef<Path>>(&self, path: P) -> Result<FullAdamState> {
1624 let weights = Weights::load(path.as_ref()).with_context(|| {
1625 format!(
1626 "failed to load optimizer moments from {}",
1627 path.as_ref().display()
1628 )
1629 })?;
1630 let mut adam = self.new_full_adam_state();
1631 let load_state = |name: &str, st: &mut AdamTensorState| -> Result<()> {
1632 let m_name = format!("{name}.m");
1633 let v_name = format!("{name}.v");
1634 let m_t = weights
1635 .require(&m_name)
1636 .with_context(|| format!("missing optimizer tensor '{m_name}'"))?;
1637 let v_t = weights
1638 .require(&v_name)
1639 .with_context(|| format!("missing optimizer tensor '{v_name}'"))?;
1640 if m_t.data().len() != st.m.len() {
1641 bail!(
1642 "optimizer tensor '{}' len {} != expected {}",
1643 m_name,
1644 m_t.data().len(),
1645 st.m.len()
1646 );
1647 }
1648 if v_t.data().len() != st.v.len() {
1649 bail!(
1650 "optimizer tensor '{}' len {} != expected {}",
1651 v_name,
1652 v_t.data().len(),
1653 st.v.len()
1654 );
1655 }
1656 st.m.as_mut_slice().copy_from_slice(m_t.data());
1657 st.v.as_mut_slice().copy_from_slice(v_t.data());
1658 Ok(())
1659 };
1660
1661 let c = self.cfg.hidden_size;
1662 let i = self.cfg.intermediate_size;
1663 let v = self.cfg.vocab_size;
1664 let h = self.cfg.num_heads;
1665 let n = self.cfg.head_dim;
1666 let _ = (c, i, v, h, n);
1667 load_state("opt.model.embeddings.weight", &mut adam.embeddings)?;
1668 load_state("opt.model.norm.weight", &mut adam.ln_out_w)?;
1669 load_state("opt.model.norm.bias", &mut adam.ln_out_b)?;
1670 load_state("opt.lm_head.weight", &mut adam.lm_head)?;
1671 for (idx, b) in adam.blocks.iter_mut().enumerate() {
1672 let p = format!("opt.model.layers.{idx}");
1673 if let Some(st) = b.pre_norm_w.as_mut() {
1674 load_state(&format!("{p}.pre_norm.weight"), st)?;
1675 }
1676 if let Some(st) = b.pre_norm_b.as_mut() {
1677 load_state(&format!("{p}.pre_norm.bias"), st)?;
1678 }
1679 load_state(&format!("{p}.attn_norm.weight"), &mut b.attn_norm_w)?;
1680 load_state(&format!("{p}.attn_norm.bias"), &mut b.attn_norm_b)?;
1681 load_state(&format!("{p}.ffn_norm.weight"), &mut b.ffn_norm_w)?;
1682 load_state(&format!("{p}.ffn_norm.bias"), &mut b.ffn_norm_b)?;
1683 load_state(&format!("{p}.attn.x_r"), &mut b.attn.x_r)?;
1684 load_state(&format!("{p}.attn.x_w"), &mut b.attn.x_w)?;
1685 load_state(&format!("{p}.attn.x_k"), &mut b.attn.x_k)?;
1686 load_state(&format!("{p}.attn.x_v"), &mut b.attn.x_v)?;
1687 load_state(&format!("{p}.attn.x_a"), &mut b.attn.x_a)?;
1688 load_state(&format!("{p}.attn.x_g"), &mut b.attn.x_g)?;
1689 load_state(&format!("{p}.attn.rkv_proj"), &mut b.attn.rkv_proj)?;
1690 load_state(&format!("{p}.attn.o_proj.weight"), &mut b.attn.o_proj)?;
1691 load_state(&format!("{p}.attn.w_lora.lora.0.weight"), &mut b.attn.w1)?;
1692 load_state(&format!("{p}.attn.w_lora.lora.2.weight"), &mut b.attn.w2)?;
1693 load_state(&format!("{p}.attn.w_lora.lora.2.bias"), &mut b.attn.w0)?;
1694 load_state(&format!("{p}.attn.a_lora.lora.0.weight"), &mut b.attn.a1)?;
1695 load_state(&format!("{p}.attn.a_lora.lora.2.weight"), &mut b.attn.a2)?;
1696 load_state(&format!("{p}.attn.a_lora.lora.2.bias"), &mut b.attn.a0)?;
1697 if let Some(st) = b.attn.v1.as_mut() {
1698 load_state(&format!("{p}.attn.v_lora.lora.0.weight"), st)?;
1699 }
1700 if let Some(st) = b.attn.v2.as_mut() {
1701 load_state(&format!("{p}.attn.v_lora.lora.2.weight"), st)?;
1702 }
1703 if let Some(st) = b.attn.v0.as_mut() {
1704 load_state(&format!("{p}.attn.v_lora.lora.2.bias"), st)?;
1705 }
1706 load_state(&format!("{p}.attn.g_lora.lora.0.weight"), &mut b.attn.g1)?;
1707 load_state(&format!("{p}.attn.g_lora.lora.2.weight"), &mut b.attn.g2)?;
1708 load_state(&format!("{p}.attn.k_k"), &mut b.attn.k_k)?;
1709 load_state(&format!("{p}.attn.k_a"), &mut b.attn.k_a)?;
1710 load_state(&format!("{p}.attn.r_k"), &mut b.attn.r_k)?;
1711 load_state(&format!("{p}.attn.g_norm.weight"), &mut b.attn.g_norm_w)?;
1712 load_state(&format!("{p}.attn.g_norm.bias"), &mut b.attn.g_norm_b)?;
1713 load_state(&format!("{p}.ffn.x_k"), &mut b.ffn.x_k)?;
1714 load_state(&format!("{p}.ffn.key.weight"), &mut b.ffn.key_w)?;
1715 load_state(&format!("{p}.ffn.value.weight"), &mut b.ffn.value_w)?;
1716 }
1717 Ok(adam)
1718 }
1719
    /// Returns the configuration this model was built with.
    pub fn config(&self) -> &Config {
        &self.cfg
    }
1724
    /// Allocates a fresh recurrent state sized for this model's configuration.
    pub fn new_state(&self) -> State {
        State::new(&self.cfg)
    }
1729
    /// Read-only view of the flattened LM-head projection weights.
    #[inline]
    pub fn lm_head_weights(&self) -> &[f32] {
        self.lm_head.as_slice()
    }
1735
    /// Mutable view of the flattened LM-head projection weights.
    #[inline]
    pub fn lm_head_weights_mut(&mut self) -> &mut [f32] {
        self.lm_head.as_mut_slice()
    }
1741
    /// Applies accumulated gradients to every parameter group selected by
    /// `scope`, using either plain SGD or Adam.
    ///
    /// For Adam, the step counter `adam_t` is incremented once per call and
    /// used to compute the bias-correction terms; moment buffers live in
    /// `model_adam` (model parameters) and `out_bias_adam_m`/`out_bias_adam_v`
    /// (the optional external output bias). `clip` is forwarded to the
    /// per-vector update helpers; what clipping semantics it carries is
    /// defined by `sgd_vec_update` / `apply_adam_vec_update` (not visible here).
    ///
    /// # Errors
    /// Fails when Adam is selected but the required optimizer state is absent
    /// (full-model state for non-head scopes, or the output-bias m/v buffers).
    #[allow(clippy::too_many_arguments)]
    fn apply_full_gradients(
        &mut self,
        grads: &FullGradState,
        scope: TrainScopeMask,
        optimizer: OptimizerKind,
        lr: f32,
        clip: f32,
        adam_t: &mut usize,
        model_adam: Option<&mut FullAdamState>,
        out_bias: Option<&mut [f32]>,
        out_bias_grad: Option<&[f32]>,
        out_bias_adam_m: Option<&mut [f32]>,
        out_bias_adam_v: Option<&mut [f32]>,
    ) -> Result<()> {
        // Precompute the per-step Adam configuration (shared by every tensor
        // updated in this call) and validate state availability up front.
        let mut adam_step = None::<AdamStep>;
        let mut model_adam = model_adam;
        if matches!(optimizer, OptimizerKind::Adam) {
            *adam_t = adam_t.saturating_add(1);
            let t = (*adam_t).max(1) as i32;
            let b1 = 0.9f32;
            let b2 = 0.999f32;
            adam_step = Some(AdamStep {
                lr,
                clip: clip.max(0.0),
                b1,
                b2,
                eps: 1e-8,
                bias_corr1: 1.0 - b1.powi(t),
                bias_corr2: 1.0 - b2.powi(t),
            });
            // NOTE(review): this guard only covers non-head parameters; the
            // `scope.head` / `scope.embed` branches below `expect` model_adam
            // unconditionally. Confirm trains_non_head_params() implies those
            // scopes are covered, otherwise a head-only Adam call with no
            // state would panic rather than bail.
            if scope.trains_non_head_params() && model_adam.is_none() {
                bail!("rwkv Adam full-training state is missing");
            }
        }

        // --- External output bias (optional, lives outside the model) ---
        if scope.bias
            && let (Some(bias), Some(grad)) = (out_bias, out_bias_grad)
        {
            match optimizer {
                OptimizerKind::Sgd => sgd_vec_update(bias, grad, lr, clip),
                OptimizerKind::Adam => {
                    let cfg = adam_step.as_ref().expect("adam cfg initialized");
                    let Some(m) = out_bias_adam_m else {
                        bail!("rwkv Adam output-bias state is missing (m)");
                    };
                    let Some(v) = out_bias_adam_v else {
                        bail!("rwkv Adam output-bias state is missing (v)");
                    };
                    apply_adam_vec_update_raw(bias, grad, m, v, cfg);
                }
            }
        }

        // --- Head: LM projection + final layer norm ---
        if scope.head {
            match optimizer {
                OptimizerKind::Sgd => {
                    sgd_vec_update(
                        self.lm_head.as_mut_slice(),
                        grads.lm_head.as_slice(),
                        lr,
                        clip,
                    );
                    sgd_vec_update(
                        self.ln_out_w.as_mut_slice(),
                        grads.ln_out_w.as_slice(),
                        lr,
                        clip,
                    );
                    sgd_vec_update(
                        self.ln_out_b.as_mut_slice(),
                        grads.ln_out_b.as_slice(),
                        lr,
                        clip,
                    );
                }
                OptimizerKind::Adam => {
                    let cfg = adam_step.as_ref().expect("adam cfg initialized");
                    let adam = model_adam.as_mut().expect("adam state exists");
                    apply_adam_vec_update(
                        self.lm_head.as_mut_slice(),
                        grads.lm_head.as_slice(),
                        &mut adam.lm_head,
                        cfg,
                    );
                    apply_adam_vec_update(
                        self.ln_out_w.as_mut_slice(),
                        grads.ln_out_w.as_slice(),
                        &mut adam.ln_out_w,
                        cfg,
                    );
                    apply_adam_vec_update(
                        self.ln_out_b.as_mut_slice(),
                        grads.ln_out_b.as_slice(),
                        &mut adam.ln_out_b,
                        cfg,
                    );
                }
            }
        }

        // --- Per-layer parameters, gated by the individual scope flags ---
        for layer_idx in 0..self.cfg.num_layers {
            let block = &mut self.blocks[layer_idx];
            let grad = &grads.blocks[layer_idx];
            match optimizer {
                OptimizerKind::Sgd => {
                    if scope.ffn {
                        sgd_vec_update(
                            block.ffn.x_k.as_mut_slice(),
                            grad.ffn.x_k.as_slice(),
                            lr,
                            clip,
                        );
                        sgd_vec_update(
                            block.ffn.key_w.as_mut_slice(),
                            grad.ffn.key_w.as_slice(),
                            lr,
                            clip,
                        );
                        sgd_vec_update(
                            block.ffn.value_w.as_mut_slice(),
                            grad.ffn.value_w.as_slice(),
                            lr,
                            clip,
                        );
                    }
                    if scope.ffn_norm {
                        sgd_vec_update(
                            block.ffn_norm_w.as_mut_slice(),
                            grad.ffn_norm_w.as_slice(),
                            lr,
                            clip,
                        );
                        sgd_vec_update(
                            block.ffn_norm_b.as_mut_slice(),
                            grad.ffn_norm_b.as_slice(),
                            lr,
                            clip,
                        );
                    }
                    if scope.attn {
                        sgd_vec_update(
                            block.attn.o_proj.as_mut_slice(),
                            grad.attn.o_proj.as_slice(),
                            lr,
                            clip,
                        );
                        sgd_vec_update(
                            block.attn.r_k.as_mut_slice(),
                            grad.attn.r_k.as_slice(),
                            lr,
                            clip,
                        );
                        sgd_vec_update(
                            block.attn.g_norm_w.as_mut_slice(),
                            grad.attn.g_norm_w.as_slice(),
                            lr,
                            clip,
                        );
                        sgd_vec_update(
                            block.attn.g_norm_b.as_mut_slice(),
                            grad.attn.g_norm_b.as_slice(),
                            lr,
                            clip,
                        );
                        sgd_vec_update(
                            block.attn.k_a.as_mut_slice(),
                            grad.attn.k_a.as_slice(),
                            lr,
                            clip,
                        );
                        sgd_vec_update(
                            block.attn.k_k.as_mut_slice(),
                            grad.attn.k_k.as_slice(),
                            lr,
                            clip,
                        );
                        sgd_vec_update(
                            block.attn.rkv_proj.as_mut_slice(),
                            grad.attn.rkv_proj.as_slice(),
                            lr,
                            clip,
                        );
                        sgd_vec_update(
                            block.attn.w0.as_mut_slice(),
                            grad.attn.w0.as_slice(),
                            lr,
                            clip,
                        );
                        sgd_vec_update(
                            block.attn.w2.as_mut_slice(),
                            grad.attn.w2.as_slice(),
                            lr,
                            clip,
                        );
                        sgd_vec_update(
                            block.attn.w1.as_mut_slice(),
                            grad.attn.w1.as_slice(),
                            lr,
                            clip,
                        );
                        sgd_vec_update(
                            block.attn.a0.as_mut_slice(),
                            grad.attn.a0.as_slice(),
                            lr,
                            clip,
                        );
                        sgd_vec_update(
                            block.attn.a2.as_mut_slice(),
                            grad.attn.a2.as_slice(),
                            lr,
                            clip,
                        );
                        sgd_vec_update(
                            block.attn.a1.as_mut_slice(),
                            grad.attn.a1.as_slice(),
                            lr,
                            clip,
                        );
                        sgd_vec_update(
                            block.attn.g2.as_mut_slice(),
                            grad.attn.g2.as_slice(),
                            lr,
                            clip,
                        );
                        sgd_vec_update(
                            block.attn.g1.as_mut_slice(),
                            grad.attn.g1.as_slice(),
                            lr,
                            clip,
                        );
                        sgd_vec_update(
                            block.attn.x_r.as_mut_slice(),
                            grad.attn.x_r.as_slice(),
                            lr,
                            clip,
                        );
                        sgd_vec_update(
                            block.attn.x_w.as_mut_slice(),
                            grad.attn.x_w.as_slice(),
                            lr,
                            clip,
                        );
                        sgd_vec_update(
                            block.attn.x_k.as_mut_slice(),
                            grad.attn.x_k.as_slice(),
                            lr,
                            clip,
                        );
                        sgd_vec_update(
                            block.attn.x_v.as_mut_slice(),
                            grad.attn.x_v.as_slice(),
                            lr,
                            clip,
                        );
                        sgd_vec_update(
                            block.attn.x_a.as_mut_slice(),
                            grad.attn.x_a.as_slice(),
                            lr,
                            clip,
                        );
                        sgd_vec_update(
                            block.attn.x_g.as_mut_slice(),
                            grad.attn.x_g.as_slice(),
                            lr,
                            clip,
                        );
                        // v-LoRA tensors are optional (absent in some layers);
                        // update only when both parameter and gradient exist.
                        if let (Some(v1), Some(gv1)) =
                            (block.attn.v1.as_mut(), grad.attn.v1.as_ref())
                        {
                            sgd_vec_update(v1.as_mut_slice(), gv1.as_slice(), lr, clip);
                        }
                        if let (Some(v2), Some(gv2)) =
                            (block.attn.v2.as_mut(), grad.attn.v2.as_ref())
                        {
                            sgd_vec_update(v2.as_mut_slice(), gv2.as_slice(), lr, clip);
                        }
                        if let (Some(v0), Some(gv0)) =
                            (block.attn.v0.as_mut(), grad.attn.v0.as_ref())
                        {
                            sgd_vec_update(v0.as_mut_slice(), gv0.as_slice(), lr, clip);
                        }
                    }
                    if scope.attn_norm {
                        sgd_vec_update(
                            block.attn_norm_w.as_mut_slice(),
                            grad.attn_norm_w.as_slice(),
                            lr,
                            clip,
                        );
                        sgd_vec_update(
                            block.attn_norm_b.as_mut_slice(),
                            grad.attn_norm_b.as_slice(),
                            lr,
                            clip,
                        );
                    }
                    // Pre-norm is optional per block; only update when present.
                    if scope.pre_norm
                        && let (Some(w), Some(gw)) =
                            (block.pre_norm_w.as_mut(), grad.pre_norm_w.as_ref())
                    {
                        sgd_vec_update(w.as_mut_slice(), gw.as_slice(), lr, clip);
                    }
                    if scope.pre_norm
                        && let (Some(b), Some(gb)) =
                            (block.pre_norm_b.as_mut(), grad.pre_norm_b.as_ref())
                    {
                        sgd_vec_update(b.as_mut_slice(), gb.as_slice(), lr, clip);
                    }
                }
                OptimizerKind::Adam => {
                    // Same update order as the SGD arm, but each tensor also
                    // advances its matching Adam moment buffers.
                    let cfg = adam_step.as_ref().expect("adam cfg initialized");
                    let adam =
                        &mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
                    if scope.ffn {
                        apply_adam_vec_update(
                            block.ffn.x_k.as_mut_slice(),
                            grad.ffn.x_k.as_slice(),
                            &mut adam.ffn.x_k,
                            cfg,
                        );
                        apply_adam_vec_update(
                            block.ffn.key_w.as_mut_slice(),
                            grad.ffn.key_w.as_slice(),
                            &mut adam.ffn.key_w,
                            cfg,
                        );
                        apply_adam_vec_update(
                            block.ffn.value_w.as_mut_slice(),
                            grad.ffn.value_w.as_slice(),
                            &mut adam.ffn.value_w,
                            cfg,
                        );
                    }
                    if scope.ffn_norm {
                        apply_adam_vec_update(
                            block.ffn_norm_w.as_mut_slice(),
                            grad.ffn_norm_w.as_slice(),
                            &mut adam.ffn_norm_w,
                            cfg,
                        );
                        apply_adam_vec_update(
                            block.ffn_norm_b.as_mut_slice(),
                            grad.ffn_norm_b.as_slice(),
                            &mut adam.ffn_norm_b,
                            cfg,
                        );
                    }
                    if scope.attn {
                        apply_adam_vec_update(
                            block.attn.o_proj.as_mut_slice(),
                            grad.attn.o_proj.as_slice(),
                            &mut adam.attn.o_proj,
                            cfg,
                        );
                        apply_adam_vec_update(
                            block.attn.r_k.as_mut_slice(),
                            grad.attn.r_k.as_slice(),
                            &mut adam.attn.r_k,
                            cfg,
                        );
                        apply_adam_vec_update(
                            block.attn.g_norm_w.as_mut_slice(),
                            grad.attn.g_norm_w.as_slice(),
                            &mut adam.attn.g_norm_w,
                            cfg,
                        );
                        apply_adam_vec_update(
                            block.attn.g_norm_b.as_mut_slice(),
                            grad.attn.g_norm_b.as_slice(),
                            &mut adam.attn.g_norm_b,
                            cfg,
                        );
                        apply_adam_vec_update(
                            block.attn.k_a.as_mut_slice(),
                            grad.attn.k_a.as_slice(),
                            &mut adam.attn.k_a,
                            cfg,
                        );
                        apply_adam_vec_update(
                            block.attn.k_k.as_mut_slice(),
                            grad.attn.k_k.as_slice(),
                            &mut adam.attn.k_k,
                            cfg,
                        );
                        apply_adam_vec_update(
                            block.attn.rkv_proj.as_mut_slice(),
                            grad.attn.rkv_proj.as_slice(),
                            &mut adam.attn.rkv_proj,
                            cfg,
                        );
                        apply_adam_vec_update(
                            block.attn.w0.as_mut_slice(),
                            grad.attn.w0.as_slice(),
                            &mut adam.attn.w0,
                            cfg,
                        );
                        apply_adam_vec_update(
                            block.attn.w2.as_mut_slice(),
                            grad.attn.w2.as_slice(),
                            &mut adam.attn.w2,
                            cfg,
                        );
                        apply_adam_vec_update(
                            block.attn.w1.as_mut_slice(),
                            grad.attn.w1.as_slice(),
                            &mut adam.attn.w1,
                            cfg,
                        );
                        apply_adam_vec_update(
                            block.attn.a0.as_mut_slice(),
                            grad.attn.a0.as_slice(),
                            &mut adam.attn.a0,
                            cfg,
                        );
                        apply_adam_vec_update(
                            block.attn.a2.as_mut_slice(),
                            grad.attn.a2.as_slice(),
                            &mut adam.attn.a2,
                            cfg,
                        );
                        apply_adam_vec_update(
                            block.attn.a1.as_mut_slice(),
                            grad.attn.a1.as_slice(),
                            &mut adam.attn.a1,
                            cfg,
                        );
                        apply_adam_vec_update(
                            block.attn.g2.as_mut_slice(),
                            grad.attn.g2.as_slice(),
                            &mut adam.attn.g2,
                            cfg,
                        );
                        apply_adam_vec_update(
                            block.attn.g1.as_mut_slice(),
                            grad.attn.g1.as_slice(),
                            &mut adam.attn.g1,
                            cfg,
                        );
                        apply_adam_vec_update(
                            block.attn.x_r.as_mut_slice(),
                            grad.attn.x_r.as_slice(),
                            &mut adam.attn.x_r,
                            cfg,
                        );
                        apply_adam_vec_update(
                            block.attn.x_w.as_mut_slice(),
                            grad.attn.x_w.as_slice(),
                            &mut adam.attn.x_w,
                            cfg,
                        );
                        apply_adam_vec_update(
                            block.attn.x_k.as_mut_slice(),
                            grad.attn.x_k.as_slice(),
                            &mut adam.attn.x_k,
                            cfg,
                        );
                        apply_adam_vec_update(
                            block.attn.x_v.as_mut_slice(),
                            grad.attn.x_v.as_slice(),
                            &mut adam.attn.x_v,
                            cfg,
                        );
                        apply_adam_vec_update(
                            block.attn.x_a.as_mut_slice(),
                            grad.attn.x_a.as_slice(),
                            &mut adam.attn.x_a,
                            cfg,
                        );
                        apply_adam_vec_update(
                            block.attn.x_g.as_mut_slice(),
                            grad.attn.x_g.as_slice(),
                            &mut adam.attn.x_g,
                            cfg,
                        );
                        // Optional v-LoRA: requires parameter, gradient, AND
                        // Adam state to all be present.
                        if let (Some(v1), Some(gv1), Some(av1)) = (
                            block.attn.v1.as_mut(),
                            grad.attn.v1.as_ref(),
                            adam.attn.v1.as_mut(),
                        ) {
                            apply_adam_vec_update(v1.as_mut_slice(), gv1.as_slice(), av1, cfg);
                        }
                        if let (Some(v2), Some(gv2), Some(av2)) = (
                            block.attn.v2.as_mut(),
                            grad.attn.v2.as_ref(),
                            adam.attn.v2.as_mut(),
                        ) {
                            apply_adam_vec_update(v2.as_mut_slice(), gv2.as_slice(), av2, cfg);
                        }
                        if let (Some(v0), Some(gv0), Some(av0)) = (
                            block.attn.v0.as_mut(),
                            grad.attn.v0.as_ref(),
                            adam.attn.v0.as_mut(),
                        ) {
                            apply_adam_vec_update(v0.as_mut_slice(), gv0.as_slice(), av0, cfg);
                        }
                    }
                    if scope.attn_norm {
                        apply_adam_vec_update(
                            block.attn_norm_w.as_mut_slice(),
                            grad.attn_norm_w.as_slice(),
                            &mut adam.attn_norm_w,
                            cfg,
                        );
                        apply_adam_vec_update(
                            block.attn_norm_b.as_mut_slice(),
                            grad.attn_norm_b.as_slice(),
                            &mut adam.attn_norm_b,
                            cfg,
                        );
                    }
                    if scope.pre_norm
                        && let (Some(w), Some(gw), Some(aw)) = (
                            block.pre_norm_w.as_mut(),
                            grad.pre_norm_w.as_ref(),
                            adam.pre_norm_w.as_mut(),
                        )
                    {
                        apply_adam_vec_update(w.as_mut_slice(), gw.as_slice(), aw, cfg);
                    }
                    if scope.pre_norm
                        && let (Some(b), Some(gb), Some(ab)) = (
                            block.pre_norm_b.as_mut(),
                            grad.pre_norm_b.as_ref(),
                            adam.pre_norm_b.as_mut(),
                        )
                    {
                        apply_adam_vec_update(b.as_mut_slice(), gb.as_slice(), ab, cfg);
                    }
                }
            }
        }

        // --- Token embeddings ---
        if scope.embed {
            match optimizer {
                OptimizerKind::Sgd => {
                    sgd_vec_update(
                        self.embeddings.as_mut_slice(),
                        grads.embeddings.as_slice(),
                        lr,
                        clip,
                    );
                }
                OptimizerKind::Adam => {
                    let cfg = adam_step.as_ref().expect("adam cfg initialized");
                    let adam = model_adam.as_mut().expect("adam state exists");
                    apply_adam_vec_update(
                        self.embeddings.as_mut_slice(),
                        grads.embeddings.as_slice(),
                        &mut adam.embeddings,
                        cfg,
                    );
                }
            }
        }
        Ok(())
    }
2299
2300 #[allow(clippy::needless_range_loop)]
2301 fn accumulate_token_step_gradients(
2302 &self,
2303 scratch: &mut ScratchBuffers,
2304 trace: &TokenTrainTrace,
2305 state_new: &State,
2306 symbol: u8,
2307 pdf: &[f64],
2308 grad_scale: f32,
2309 scope: TrainScopeMask,
2310 grads: &mut FullGradState,
2311 out_bias_grad: Option<&mut [f32]>,
2312 future: &mut RecurrentGradState,
2313 ) -> Result<()> {
2314 let c = self.cfg.hidden_size;
2315 let h = self.cfg.num_heads;
2316 let n = self.cfg.head_dim;
2317 let i = self.cfg.intermediate_size;
2318 let d_w = self.cfg.decay_low_rank;
2319 let d_a = self.cfg.a_low_rank;
2320 let d_v = self.cfg.v_low_rank;
2321 let d_g = self.cfg.g_low_rank;
2322 let vocab = self.cfg.vocab_size.min(pdf.len());
2323 if vocab == 0 {
2324 return Ok(());
2325 }
2326
2327 scratch.grad_logits.zero();
2328 for idx in 0..vocab {
2329 let p = pdf[idx].clamp(1e-12, 1.0) as f32;
2330 let target = if idx == symbol as usize { 1.0 } else { 0.0 };
2331 scratch.grad_logits[idx] = (target - p) * grad_scale;
2332 }
2333
2334 if scope.bias
2335 && let Some(bias_grad) = out_bias_grad
2336 {
2337 add_vec_grad(
2338 &mut bias_grad[0..vocab],
2339 &scratch.grad_logits.as_slice()[0..vocab],
2340 );
2341 }
2342
2343 scratch.grad_x.zero();
2344 if scope.head {
2345 add_outer_grad(
2346 grads.lm_head.as_mut_slice(),
2347 vocab,
2348 c,
2349 &scratch.grad_logits.as_slice()[0..vocab],
2350 trace.x_normed.as_slice(),
2351 );
2352 }
2353 for row in 0..vocab {
2354 let g = scratch.grad_logits[row];
2355 if g == 0.0 {
2356 continue;
2357 }
2358 let row_off = row * c;
2359 for col in 0..c {
2360 scratch.grad_x[col] += self.lm_head[row_off + col] * g;
2361 }
2362 }
2363
2364 let needs_backprop = scope.trains_non_head_params() || scope.head;
2365 if !needs_backprop {
2366 return Ok(());
2367 }
2368
2369 layer_norm_backward(
2370 trace.x.as_slice(),
2371 self.ln_out_w.as_slice(),
2372 scratch.grad_x.as_slice(),
2373 self.cfg.layer_norm_eps,
2374 scratch.grad_x2.as_mut_slice(),
2375 scratch.grad_x3.as_mut_slice(),
2376 scratch.grad_x4.as_mut_slice(),
2377 );
2378 if scope.head {
2379 add_vec_grad(grads.ln_out_w.as_mut_slice(), scratch.grad_x3.as_slice());
2380 add_vec_grad(grads.ln_out_b.as_mut_slice(), scratch.grad_x4.as_slice());
2381 }
2382 scratch.grad_x.copy_from_slice(scratch.grad_x2.as_slice());
2383 scratch.grad_v_first.zero();
2384
2385 for layer_idx in (0..self.cfg.num_layers).rev() {
2386 let tr = &trace.layers[layer_idx];
2387 let block = &self.blocks[layer_idx];
2388 let block_grads = &mut grads.blocks[layer_idx];
2389 let future_layer = &mut future.layers[layer_idx];
2390
2391 scratch.grad_x2.copy_from_slice(scratch.grad_x.as_slice());
2392 scratch.grad_x3.copy_from_slice(scratch.grad_x.as_slice());
2393
2394 unsafe {
2395 kernel::gemv_t_avx(
2396 block.ffn.value_w.as_ptr(),
2397 scratch.grad_x3.as_ptr(),
2398 scratch.grad_ffn.as_mut_ptr(),
2399 c,
2400 i,
2401 );
2402 }
2403 if scope.ffn {
2404 add_outer_grad(
2405 block_grads.ffn.value_w.as_mut_slice(),
2406 c,
2407 i,
2408 scratch.grad_x3.as_slice(),
2409 tr.ffn_k.as_slice(),
2410 );
2411 }
2412
2413 for col in 0..i {
2414 let pre = tr.ffn_pre[col];
2415 scratch.grad_ffn2[col] = if pre > 0.0 {
2416 scratch.grad_ffn[col] * (2.0 * pre)
2417 } else {
2418 0.0
2419 };
2420 }
2421
2422 unsafe {
2423 kernel::gemv_t_avx(
2424 block.ffn.key_w.as_ptr(),
2425 scratch.grad_ffn2.as_ptr(),
2426 scratch.grad_x4.as_mut_ptr(),
2427 i,
2428 c,
2429 );
2430 }
2431 if scope.ffn {
2432 add_outer_grad(
2433 block_grads.ffn.key_w.as_mut_slice(),
2434 i,
2435 c,
2436 scratch.grad_ffn2.as_slice(),
2437 tr.ffn_xk.as_slice(),
2438 );
2439 }
2440
2441 scratch
2442 .grad_x5
2443 .copy_from_slice(future_layer.ffn_x_prev.as_slice());
2444 future_layer.ffn_x_prev.zero();
2445 for col in 0..c {
2446 let g = scratch.grad_x4[col];
2447 let mix = block.ffn.x_k[col];
2448 let base = tr.ffn_norm[col];
2449 let prev = tr.ffn_x_prev_old[col];
2450 scratch.grad_x5[col] += g * (1.0 - mix);
2451 future_layer.ffn_x_prev[col] = g * mix;
2452 scratch.grad_param[col] = g * (prev - base);
2453 }
2454 if scope.ffn {
2455 add_vec_grad(
2456 block_grads.ffn.x_k.as_mut_slice(),
2457 scratch.grad_param.as_slice(),
2458 );
2459 }
2460
2461 layer_norm_backward(
2462 tr.x_after_attn.as_slice(),
2463 block.ffn_norm_w.as_slice(),
2464 scratch.grad_x5.as_slice(),
2465 self.cfg.layer_norm_eps,
2466 scratch.grad_x4.as_mut_slice(),
2467 scratch.grad_x3.as_mut_slice(),
2468 scratch.grad_x6.as_mut_slice(),
2469 );
2470 if scope.ffn_norm {
2471 add_vec_grad(
2472 block_grads.ffn_norm_w.as_mut_slice(),
2473 scratch.grad_x3.as_slice(),
2474 );
2475 add_vec_grad(
2476 block_grads.ffn_norm_b.as_mut_slice(),
2477 scratch.grad_x6.as_slice(),
2478 );
2479 }
2480 for col in 0..c {
2481 scratch.grad_x2[col] += scratch.grad_x4[col];
2482 }
2483
2484 scratch.grad_x.copy_from_slice(scratch.grad_x2.as_slice());
2485 scratch.grad_x3.copy_from_slice(scratch.grad_x2.as_slice());
2486
2487 unsafe {
2488 kernel::gemv_t_avx(
2489 block.attn.o_proj.as_ptr(),
2490 scratch.grad_x3.as_ptr(),
2491 scratch.grad_x4.as_mut_ptr(),
2492 c,
2493 c,
2494 );
2495 }
2496 if scope.attn {
2497 add_outer_grad(
2498 block_grads.attn.o_proj.as_mut_slice(),
2499 c,
2500 c,
2501 scratch.grad_x3.as_slice(),
2502 tr.y_gate.as_slice(),
2503 );
2504 }
2505
2506 for col in 0..c {
2507 let gy = scratch.grad_x4[col];
2508 scratch.grad_saved[col] = gy * tr.y_head[col];
2509 scratch.grad_x4[col] = gy * tr.g[col];
2510 }
2511
2512 scratch.grad_x2.zero();
2513 scratch.grad_x3.zero();
2514 scratch.grad_x6.zero();
2515 scratch.grad_param.zero();
2516 for head_idx in 0..h {
2517 let off = head_idx * n;
2518 let mut g_alpha = 0.0f32;
2519 for j in 0..n {
2520 let g = scratch.grad_x4[off + j];
2521 g_alpha += g * tr.v[off + j];
2522 scratch.grad_x6[off + j] += g * tr.alpha[head_idx];
2523 }
2524 for j in 0..n {
2525 let idx = off + j;
2526 let rk = block.attn.r_k[idx];
2527 let rv = tr.r[idx];
2528 let kv = tr.k[idx];
2529 let g = g_alpha * rk;
2530 scratch.grad_x2[idx] += g * kv;
2531 scratch.grad_x3[idx] += g * rv;
2532 scratch.grad_param[idx] += g_alpha * rv * kv;
2533 }
2534 }
2535 if scope.attn {
2536 add_vec_grad(
2537 block_grads.attn.r_k.as_mut_slice(),
2538 scratch.grad_param.as_slice(),
2539 );
2540 }
2541
2542 scratch.grad_x5.as_mut_slice()[0..c].copy_from_slice(&scratch.grad_x4.as_slice()[0..c]);
2543 group_norm_backward(
2544 tr.y_wkv.as_slice(),
2545 block.attn.g_norm_w.as_slice(),
2546 scratch.grad_x5.as_slice(),
2547 h,
2548 n,
2549 self.cfg.group_norm_eps,
2550 scratch.grad_x4.as_mut_slice(),
2551 scratch.grad_param.as_mut_slice(),
2552 scratch.grad_param2.as_mut_slice(),
2553 );
2554 if scope.attn {
2555 add_vec_grad(
2556 block_grads.attn.g_norm_w.as_mut_slice(),
2557 scratch.grad_param.as_slice(),
2558 );
2559 add_vec_grad(
2560 block_grads.attn.g_norm_b.as_mut_slice(),
2561 scratch.grad_param2.as_slice(),
2562 );
2563 }
2564
2565 scratch.grad_param.zero();
2566 scratch.grad_x5.zero();
2567 scratch.grad_param2.zero();
2568 scratch
2569 .grad_att_state
2570 .copy_from_slice(future_layer.att_state.as_slice());
2571 future_layer.att_state.zero();
2572 let s_old = tr.att_state_old.as_slice();
2573 let s_new = state_new.layers[layer_idx].att_state.as_slice();
2574 for head_idx in 0..h {
2575 let off = head_idx * n;
2576 let s_off = head_idx * n * n;
2577 let grad_y = &scratch.grad_x4.as_slice()[off..off + n];
2578 let r_head = &tr.r.as_slice()[off..off + n];
2579 let k_head = &tr.k.as_slice()[off..off + n];
2580 let kk_head = &tr.kk.as_slice()[off..off + n];
2581 let a_head = &tr.a.as_slice()[off..off + n];
2582 let v_head = &tr.v.as_slice()[off..off + n];
2583 let w_head = &tr.w_decay.as_slice()[off..off + n];
2584
2585 unsafe {
2586 kernel::gemv_t_avx(
2587 s_new.as_ptr().add(s_off),
2588 grad_y.as_ptr(),
2589 scratch.grad_low_rank.as_mut_ptr(),
2590 n,
2591 n,
2592 );
2593 }
2594 for j in 0..n {
2595 scratch.grad_x2[off + j] += scratch.grad_low_rank[j];
2596 }
2597
2598 let g_state = &mut scratch.grad_att_state.as_mut_slice()[s_off..s_off + n * n];
2599 for irow in 0..n {
2600 let gy = grad_y[irow];
2601 let row_off = irow * n;
2602 for j in 0..n {
2603 g_state[row_off + j] += gy * r_head[j];
2604 }
2605 }
2606
2607 unsafe {
2608 kernel::gemv_avx(
2609 s_old.as_ptr().add(s_off),
2610 kk_head.as_ptr(),
2611 scratch.grad_low_rank.as_mut_ptr(),
2612 n,
2613 n,
2614 );
2615 }
2616 let u = &scratch.grad_low_rank.as_slice()[0..n];
2617
2618 for j in 0..n {
2619 let mut grad_w = 0.0f32;
2620 let mut grad_k = 0.0f32;
2621 let mut grad_b = 0.0f32;
2622 for irow in 0..n {
2623 let g = g_state[irow * n + j];
2624 grad_w += g * s_old[s_off + irow * n + j];
2625 grad_k += g * v_head[irow];
2626 grad_b -= g * u[irow];
2627 future_layer.att_state[s_off + irow * n + j] = g * w_head[j];
2628 }
2629 scratch.grad_param[off + j] += grad_w;
2630 scratch.grad_x3[off + j] += grad_k;
2631 scratch.grad_param2[off + j] += grad_b * a_head[j];
2632 scratch.grad_x5[off + j] += grad_b * kk_head[j];
2633 }
2634
2635 for irow in 0..n {
2636 let mut grad_u = 0.0f32;
2637 for j in 0..n {
2638 grad_u -= g_state[irow * n + j] * kk_head[j] * a_head[j];
2639 }
2640 scratch.grad_low_rank2[irow] = grad_u;
2641 let row_off = irow * n;
2642 for j in 0..n {
2643 future_layer.att_state[s_off + row_off + j] += grad_u * kk_head[j];
2644 }
2645 }
2646 unsafe {
2647 kernel::gemv_t_avx(
2648 s_old.as_ptr().add(s_off),
2649 scratch.grad_low_rank2.as_ptr(),
2650 scratch.grad_low_rank.as_mut_ptr(),
2651 n,
2652 n,
2653 );
2654 }
2655 for j in 0..n {
2656 scratch.grad_param2[off + j] += scratch.grad_low_rank[j];
2657 }
2658
2659 for irow in 0..n {
2660 let mut grad_v = 0.0f32;
2661 for j in 0..n {
2662 grad_v += g_state[irow * n + j] * k_head[j];
2663 }
2664 scratch.grad_x6[off + irow] += grad_v;
2665 }
2666 }
2667
2668 for col in 0..c {
2669 let gk = scratch.grad_x3[col];
2670 let scale = 1.0 + (tr.a[col] - 1.0) * block.attn.k_a[col];
2671 let d_scale = gk * tr.k_pre[col];
2672 scratch.grad_x3[col] = gk * scale;
2673 scratch.grad_x5[col] += d_scale * block.attn.k_a[col];
2674 scratch.grad_param[col] = d_scale * (tr.a[col] - 1.0);
2675 }
2676 for head_idx in 0..h {
2677 let off = head_idx * n;
2678 l2_normalize_backward(
2679 &tr.kk_pre.as_slice()[off..off + n],
2680 &tr.kk.as_slice()[off..off + n],
2681 &scratch.grad_param2.as_slice()[off..off + n],
2682 1e-12,
2683 &mut scratch.grad_x4.as_mut_slice()[off..off + n],
2684 );
2685 }
2686 for col in 0..c {
2687 let g = scratch.grad_x4[col];
2688 scratch.grad_x3[col] += g * block.attn.k_k[col];
2689 scratch.grad_param2[col] = g * tr.k_pre[col];
2690 }
2691 if scope.attn {
2692 add_vec_grad(
2693 block_grads.attn.k_a.as_mut_slice(),
2694 scratch.grad_param.as_slice(),
2695 );
2696 add_vec_grad(
2697 block_grads.attn.k_k.as_mut_slice(),
2698 scratch.grad_param2.as_slice(),
2699 );
2700 }
2701
2702 scratch
2703 .grad_param2
2704 .copy_from_slice(scratch.grad_x6.as_slice());
2705 if layer_idx == 0 {
2706 for col in 0..c {
2707 scratch.grad_x6[col] += scratch.grad_v_first[col];
2708 }
2709 } else if tr.uses_v_residual
2710 && block.attn.v1.is_some()
2711 && block.attn.v2.is_some()
2712 && block.attn.v0.is_some()
2713 {
2714 let v1 = block.attn.v1.as_ref().expect("v1");
2715 let v2 = block.attn.v2.as_ref().expect("v2");
2716 for col in 0..c {
2717 let gv = scratch.grad_param2[col];
2718 let nu = tr.nu[col];
2719 scratch.grad_x6[col] = gv * (1.0 - nu);
2720 scratch.grad_x3[col] = gv * (trace.v_first[col] - tr.v_pre[col]);
2721 scratch.grad_v_first[col] += gv * nu;
2722 }
2723 for col in 0..c {
2724 let nu = tr.nu[col];
2725 scratch.grad_x3[col] *= nu * (1.0 - nu);
2726 }
2727 if scope.attn {
2728 add_vec_grad(
2729 block_grads
2730 .attn
2731 .v0
2732 .as_mut()
2733 .expect("grad v0")
2734 .as_mut_slice(),
2735 scratch.grad_x3.as_slice(),
2736 );
2737 add_outer_grad(
2738 block_grads
2739 .attn
2740 .v2
2741 .as_mut()
2742 .expect("grad v2")
2743 .as_mut_slice(),
2744 c,
2745 d_v,
2746 scratch.grad_x3.as_slice(),
2747 &tr.v_hidden.as_slice()[0..d_v],
2748 );
2749 }
2750 unsafe {
2751 kernel::gemv_t_avx(
2752 v2.as_ptr(),
2753 scratch.grad_x3.as_ptr(),
2754 scratch.grad_low_rank.as_mut_ptr(),
2755 c,
2756 d_v,
2757 );
2758 }
2759 if scope.attn {
2760 add_outer_grad(
2761 block_grads
2762 .attn
2763 .v1
2764 .as_mut()
2765 .expect("grad v1")
2766 .as_mut_slice(),
2767 d_v,
2768 c,
2769 &scratch.grad_low_rank.as_slice()[0..d_v],
2770 tr.xv.as_slice(),
2771 );
2772 }
2773 for col in 0..c {
2774 let mut acc = 0.0f32;
2775 for row in 0..d_v {
2776 acc += v1[row * c + col] * scratch.grad_low_rank[row];
2777 }
2778 scratch.grad_x4[col] += acc;
2779 }
2780 }
2781
2782 let proj_size = c * c;
2783 if scope.attn {
2784 add_outer_grad(
2785 &mut block_grads.attn.rkv_proj.as_mut_slice()[0..proj_size],
2786 c,
2787 c,
2788 scratch.grad_x2.as_slice(),
2789 tr.xr.as_slice(),
2790 );
2791 add_outer_grad(
2792 &mut block_grads.attn.rkv_proj.as_mut_slice()[proj_size..2 * proj_size],
2793 c,
2794 c,
2795 scratch.grad_x3.as_slice(),
2796 tr.xk.as_slice(),
2797 );
2798 add_outer_grad(
2799 &mut block_grads.attn.rkv_proj.as_mut_slice()[2 * proj_size..3 * proj_size],
2800 c,
2801 c,
2802 scratch.grad_x6.as_slice(),
2803 tr.xv.as_slice(),
2804 );
2805 }
2806 let proj = block.attn.rkv_proj.as_slice();
2807 unsafe {
2808 kernel::gemv_t_avx(
2809 proj.as_ptr(),
2810 scratch.grad_x2.as_ptr(),
2811 scratch.grad_param.as_mut_ptr(),
2812 c,
2813 c,
2814 );
2815 kernel::gemv_t_avx(
2816 proj.as_ptr().add(proj_size),
2817 scratch.grad_x3.as_ptr(),
2818 scratch.grad_param2.as_mut_ptr(),
2819 c,
2820 c,
2821 );
2822 kernel::gemv_t_avx(
2823 proj.as_ptr().add(2 * proj_size),
2824 scratch.grad_x6.as_ptr(),
2825 scratch.grad_x4.as_mut_ptr(),
2826 c,
2827 c,
2828 );
2829 }
2830
2831 let inv_sqrt_e = 1.0 / std::f32::consts::E.sqrt();
2832 for col in 0..c {
2833 let sig = tr.w_sigmoid[col];
2834 let d_sig = scratch.grad_param[col] * (-inv_sqrt_e) * tr.w_decay[col];
2835 scratch.grad_param[col] = d_sig * sig * (1.0 - sig);
2836 }
2837 if scope.attn {
2838 add_vec_grad(
2839 block_grads.attn.w0.as_mut_slice(),
2840 scratch.grad_param.as_slice(),
2841 );
2842 add_outer_grad(
2843 block_grads.attn.w2.as_mut_slice(),
2844 c,
2845 d_w,
2846 scratch.grad_param.as_slice(),
2847 &tr.w_hidden.as_slice()[0..d_w],
2848 );
2849 }
2850 unsafe {
2851 kernel::gemv_t_avx(
2852 block.attn.w2.as_ptr(),
2853 scratch.grad_param.as_ptr(),
2854 scratch.grad_low_rank.as_mut_ptr(),
2855 c,
2856 d_w,
2857 );
2858 }
2859 for col in 0..d_w {
2860 let t = tr.w_hidden[col];
2861 scratch.grad_low_rank[col] *= 1.0 - t * t;
2862 }
2863 if scope.attn {
2864 add_outer_grad(
2865 block_grads.attn.w1.as_mut_slice(),
2866 d_w,
2867 c,
2868 &scratch.grad_low_rank.as_slice()[0..d_w],
2869 tr.xw.as_slice(),
2870 );
2871 }
2872 unsafe {
2873 kernel::gemv_t_avx(
2874 block.attn.w1.as_ptr(),
2875 scratch.grad_low_rank.as_ptr(),
2876 scratch.grad_x6.as_mut_ptr(),
2877 d_w,
2878 c,
2879 );
2880 }
2881
2882 for col in 0..c {
2883 let a = tr.a[col];
2884 scratch.grad_x5[col] *= a * (1.0 - a);
2885 }
2886 if scope.attn {
2887 add_vec_grad(
2888 block_grads.attn.a0.as_mut_slice(),
2889 scratch.grad_x5.as_slice(),
2890 );
2891 add_outer_grad(
2892 block_grads.attn.a2.as_mut_slice(),
2893 c,
2894 d_a,
2895 scratch.grad_x5.as_slice(),
2896 &tr.a_hidden.as_slice()[0..d_a],
2897 );
2898 }
2899 unsafe {
2900 kernel::gemv_t_avx(
2901 block.attn.a2.as_ptr(),
2902 scratch.grad_x5.as_ptr(),
2903 scratch.grad_low_rank.as_mut_ptr(),
2904 c,
2905 d_a,
2906 );
2907 }
2908 if scope.attn {
2909 add_outer_grad(
2910 block_grads.attn.a1.as_mut_slice(),
2911 d_a,
2912 c,
2913 &scratch.grad_low_rank.as_slice()[0..d_a],
2914 tr.xa.as_slice(),
2915 );
2916 }
2917 unsafe {
2918 kernel::gemv_t_avx(
2919 block.attn.a1.as_ptr(),
2920 scratch.grad_low_rank.as_ptr(),
2921 scratch.grad_x5.as_mut_ptr(),
2922 d_a,
2923 c,
2924 );
2925 }
2926
2927 if scope.attn {
2928 add_outer_grad(
2929 block_grads.attn.g2.as_mut_slice(),
2930 c,
2931 d_g,
2932 scratch.grad_saved.as_slice(),
2933 &tr.g_hidden.as_slice()[0..d_g],
2934 );
2935 }
2936 unsafe {
2937 kernel::gemv_t_avx(
2938 block.attn.g2.as_ptr(),
2939 scratch.grad_saved.as_ptr(),
2940 scratch.grad_low_rank.as_mut_ptr(),
2941 c,
2942 d_g,
2943 );
2944 }
2945 for col in 0..d_g {
2946 let sig = tr.g_hidden[col];
2947 scratch.grad_low_rank2[col] = scratch.grad_low_rank[col] * sig * (1.0 - sig);
2948 }
2949 if scope.attn {
2950 add_outer_grad(
2951 block_grads.attn.g1.as_mut_slice(),
2952 d_g,
2953 c,
2954 &scratch.grad_low_rank2.as_slice()[0..d_g],
2955 tr.xg.as_slice(),
2956 );
2957 }
2958 unsafe {
2959 kernel::gemv_t_avx(
2960 block.attn.g1.as_ptr(),
2961 scratch.grad_low_rank2.as_ptr(),
2962 scratch.grad_saved.as_mut_ptr(),
2963 d_g,
2964 c,
2965 );
2966 }
2967
2968 scratch
2969 .grad_x3
2970 .copy_from_slice(future_layer.att_x_prev.as_slice());
2971 future_layer.att_x_prev.zero();
2972
2973 for col in 0..c {
2974 let g = scratch.grad_param[col];
2975 let mix = block.attn.x_r[col];
2976 let base = tr.attn_norm[col];
2977 let prev = tr.att_x_prev_old[col];
2978 scratch.grad_x3[col] += g * (1.0 - mix);
2979 future_layer.att_x_prev[col] += g * mix;
2980 scratch.grad_x2[col] = g * (prev - base);
2981 }
2982 if scope.attn {
2983 add_vec_grad(
2984 block_grads.attn.x_r.as_mut_slice(),
2985 scratch.grad_x2.as_slice(),
2986 );
2987 }
2988
2989 for col in 0..c {
2990 let g = scratch.grad_x6[col];
2991 let mix = block.attn.x_w[col];
2992 let base = tr.attn_norm[col];
2993 let prev = tr.att_x_prev_old[col];
2994 scratch.grad_x3[col] += g * (1.0 - mix);
2995 future_layer.att_x_prev[col] += g * mix;
2996 scratch.grad_x2[col] = g * (prev - base);
2997 }
2998 if scope.attn {
2999 add_vec_grad(
3000 block_grads.attn.x_w.as_mut_slice(),
3001 scratch.grad_x2.as_slice(),
3002 );
3003 }
3004
3005 for col in 0..c {
3006 let g = scratch.grad_param2[col];
3007 let mix = block.attn.x_k[col];
3008 let base = tr.attn_norm[col];
3009 let prev = tr.att_x_prev_old[col];
3010 scratch.grad_x3[col] += g * (1.0 - mix);
3011 future_layer.att_x_prev[col] += g * mix;
3012 scratch.grad_x2[col] = g * (prev - base);
3013 }
3014 if scope.attn {
3015 add_vec_grad(
3016 block_grads.attn.x_k.as_mut_slice(),
3017 scratch.grad_x2.as_slice(),
3018 );
3019 }
3020
3021 for col in 0..c {
3022 let g = scratch.grad_x4[col];
3023 let mix = block.attn.x_v[col];
3024 let base = tr.attn_norm[col];
3025 let prev = tr.att_x_prev_old[col];
3026 scratch.grad_x3[col] += g * (1.0 - mix);
3027 future_layer.att_x_prev[col] += g * mix;
3028 scratch.grad_x2[col] = g * (prev - base);
3029 }
3030 if scope.attn {
3031 add_vec_grad(
3032 block_grads.attn.x_v.as_mut_slice(),
3033 scratch.grad_x2.as_slice(),
3034 );
3035 }
3036
3037 for col in 0..c {
3038 let g = scratch.grad_x5[col];
3039 let mix = block.attn.x_a[col];
3040 let base = tr.attn_norm[col];
3041 let prev = tr.att_x_prev_old[col];
3042 scratch.grad_x3[col] += g * (1.0 - mix);
3043 future_layer.att_x_prev[col] += g * mix;
3044 scratch.grad_x2[col] = g * (prev - base);
3045 }
3046 if scope.attn {
3047 add_vec_grad(
3048 block_grads.attn.x_a.as_mut_slice(),
3049 scratch.grad_x2.as_slice(),
3050 );
3051 }
3052
3053 for col in 0..c {
3054 let g = scratch.grad_saved[col];
3055 let mix = block.attn.x_g[col];
3056 let base = tr.attn_norm[col];
3057 let prev = tr.att_x_prev_old[col];
3058 scratch.grad_x3[col] += g * (1.0 - mix);
3059 future_layer.att_x_prev[col] += g * mix;
3060 scratch.grad_x2[col] = g * (prev - base);
3061 }
3062 if scope.attn {
3063 add_vec_grad(
3064 block_grads.attn.x_g.as_mut_slice(),
3065 scratch.grad_x2.as_slice(),
3066 );
3067 }
3068
3069 layer_norm_backward(
3070 tr.x_after_pre.as_slice(),
3071 block.attn_norm_w.as_slice(),
3072 scratch.grad_x3.as_slice(),
3073 self.cfg.layer_norm_eps,
3074 scratch.grad_x2.as_mut_slice(),
3075 scratch.grad_x4.as_mut_slice(),
3076 scratch.grad_x5.as_mut_slice(),
3077 );
3078 if scope.attn_norm {
3079 add_vec_grad(
3080 block_grads.attn_norm_w.as_mut_slice(),
3081 scratch.grad_x4.as_slice(),
3082 );
3083 add_vec_grad(
3084 block_grads.attn_norm_b.as_mut_slice(),
3085 scratch.grad_x5.as_slice(),
3086 );
3087 }
3088 for col in 0..c {
3089 scratch.grad_x[col] += scratch.grad_x2[col];
3090 }
3091
3092 if layer_idx == 0
3093 && let (Some(w), Some(_b)) = (&block.pre_norm_w, &block.pre_norm_b)
3094 {
3095 layer_norm_backward(
3096 tr.x_in.as_slice(),
3097 w.as_slice(),
3098 scratch.grad_x.as_slice(),
3099 self.cfg.layer_norm_eps,
3100 scratch.grad_x2.as_mut_slice(),
3101 scratch.grad_x3.as_mut_slice(),
3102 scratch.grad_x4.as_mut_slice(),
3103 );
3104 if scope.pre_norm {
3105 add_vec_grad(
3106 block_grads
3107 .pre_norm_w
3108 .as_mut()
3109 .expect("grad pre_norm_w")
3110 .as_mut_slice(),
3111 scratch.grad_x3.as_slice(),
3112 );
3113 add_vec_grad(
3114 block_grads
3115 .pre_norm_b
3116 .as_mut()
3117 .expect("grad pre_norm_b")
3118 .as_mut_slice(),
3119 scratch.grad_x4.as_slice(),
3120 );
3121 }
3122 scratch.grad_x.copy_from_slice(scratch.grad_x2.as_slice());
3123 }
3124 }
3125
3126 if scope.embed {
3127 let token_idx = trace.token.min(self.cfg.vocab_size.saturating_sub(1));
3128 let off = token_idx * c;
3129 add_vec_grad(
3130 &mut grads.embeddings.as_mut_slice()[off..off + c],
3131 scratch.grad_x.as_slice(),
3132 );
3133 }
3134
3135 Ok(())
3136 }
3137
3138 #[allow(clippy::too_many_arguments)]
3139 pub fn online_train_segment_tbptt(
3141 &mut self,
3142 scratch: &mut ScratchBuffers,
3143 start_state: &State,
3144 steps: &[(u32, u8)],
3145 scope: TrainScopeMask,
3146 optimizer: OptimizerKind,
3147 lr: f32,
3148 clip: f32,
3149 replay_chunk: usize,
3150 adam_t: &mut usize,
3151 model_adam: Option<&mut FullAdamState>,
3152 out_bias: Option<&mut [f32]>,
3153 out_bias_adam_m: Option<&mut [f32]>,
3154 out_bias_adam_v: Option<&mut [f32]>,
3155 live_state_out: &mut State,
3156 ) -> Result<()> {
3157 if steps.is_empty() {
3158 *live_state_out = start_state.clone();
3159 return Ok(());
3160 }
3161
3162 let grad_scale = 1.0f32 / (steps.len() as f32);
3163 let chunk = replay_chunk.max(1).min(steps.len().max(1));
3164 let mut grads = self.new_full_grad_state();
3165 let mut recurrent = self.new_recurrent_grad_state();
3166 recurrent.zero();
3167 let mut bias_grad = out_bias.as_deref().map(|b| vec![0.0f32; b.len()]);
3168
3169 {
3170 let mut checkpoints = Vec::<State>::new();
3171 let mut checkpoint_state = start_state.clone();
3172 scratch.set_capture_train_trace(false);
3173 for chunk_start in (0..steps.len()).step_by(chunk) {
3174 checkpoints.push(checkpoint_state.clone());
3175 let chunk_end = (chunk_start + chunk).min(steps.len());
3176 for &(input_token, _) in &steps[chunk_start..chunk_end] {
3177 self.forward(scratch, input_token, &mut checkpoint_state);
3178 }
3179 }
3180
3181 for chunk_idx in (0..checkpoints.len()).rev() {
3182 let chunk_start = chunk_idx * chunk;
3183 let chunk_end = (chunk_start + chunk).min(steps.len());
3184 let mut state = checkpoints[chunk_idx].clone();
3185 let mut step_states = Vec::<State>::with_capacity(chunk_end - chunk_start + 1);
3186 let mut step_traces =
3187 Vec::<TokenTrainTrace>::with_capacity(chunk_end - chunk_start);
3188 let mut step_pdfs =
3189 Vec::<Vec<f64>>::with_capacity(chunk_end.saturating_sub(chunk_start));
3190 step_states.push(state.clone());
3191
3192 for &(input_token, _) in &steps[chunk_start..chunk_end] {
3193 scratch.set_capture_train_trace(true);
3194 let logits = self.forward(scratch, input_token, &mut state);
3195 let mut pdf = vec![0.0f64; self.cfg.vocab_size];
3196 super::super::softmax_pdf_floor_with_bias(
3197 logits,
3198 out_bias.as_deref(),
3199 &mut pdf,
3200 );
3201 step_pdfs.push(pdf);
3202 step_traces.push(TokenTrainTrace::from_scratch(scratch));
3203 step_states.push(state.clone());
3204 }
3205
3206 for local_idx in (0..step_traces.len()).rev() {
3207 let (_, target_symbol) = steps[chunk_start + local_idx];
3208 self.accumulate_token_step_gradients(
3209 scratch,
3210 &step_traces[local_idx],
3211 &step_states[local_idx + 1],
3212 target_symbol,
3213 &step_pdfs[local_idx],
3214 grad_scale,
3215 scope,
3216 &mut grads,
3217 bias_grad.as_deref_mut(),
3218 &mut recurrent,
3219 )?;
3220 }
3221 }
3222 }
3223
3224 self.apply_full_gradients(
3225 &grads,
3226 scope,
3227 optimizer,
3228 lr,
3229 clip,
3230 adam_t,
3231 model_adam,
3232 out_bias,
3233 bias_grad.as_deref(),
3234 out_bias_adam_m,
3235 out_bias_adam_v,
3236 )?;
3237
3238 scratch.set_capture_train_trace(false);
3239 *live_state_out = start_state.clone();
3240 for &(input_token, _) in steps {
3241 self.forward(scratch, input_token, live_state_out);
3242 }
3243 Ok(())
3244 }
3245
3246 #[allow(clippy::too_many_arguments)]
3248 #[allow(clippy::needless_range_loop)]
3249 pub fn online_train_step_bptt1(
3250 &mut self,
3251 scratch: &mut ScratchBuffers,
3252 state: &State,
3253 symbol: u8,
3254 pdf: &[f64],
3255 scope: TrainScopeMask,
3256 optimizer: OptimizerKind,
3257 lr: f32,
3258 clip: f32,
3259 adam_t: &mut usize,
3260 model_adam: Option<&mut FullAdamState>,
3261 out_bias: Option<&mut [f32]>,
3262 out_bias_adam_m: Option<&mut [f32]>,
3263 out_bias_adam_v: Option<&mut [f32]>,
3264 ) -> Result<()> {
3265 if !scope.trains_any_params() {
3266 return Ok(());
3267 }
3268 if scope.trains_non_head_params() && !scratch.train_trace_valid {
3269 bail!("rwkv full training trace is missing; run one forward step first");
3270 }
3271 let c = self.cfg.hidden_size;
3272 let h = self.cfg.num_heads;
3273 let n = self.cfg.head_dim;
3274 let i = self.cfg.intermediate_size;
3275 let d_w = self.cfg.decay_low_rank;
3276 let d_a = self.cfg.a_low_rank;
3277 let d_v = self.cfg.v_low_rank;
3278 let d_g = self.cfg.g_low_rank;
3279 let vocab = self.cfg.vocab_size.min(pdf.len());
3280 if vocab == 0 {
3281 return Ok(());
3282 }
3283 let mut adam_step = None::<AdamStep>;
3284 let mut model_adam = model_adam;
3285 if matches!(optimizer, OptimizerKind::Adam) {
3286 *adam_t = adam_t.saturating_add(1);
3287 let t = (*adam_t).max(1) as i32;
3288 let b1 = 0.9f32;
3289 let b2 = 0.999f32;
3290 adam_step = Some(AdamStep {
3291 lr,
3292 clip: clip.max(0.0),
3293 b1,
3294 b2,
3295 eps: 1e-8,
3296 bias_corr1: 1.0 - b1.powi(t),
3297 bias_corr2: 1.0 - b2.powi(t),
3298 });
3299 if scope.trains_non_head_params() && model_adam.is_none() {
3300 bail!("rwkv Adam full-training state is missing");
3301 }
3302 }
3303
3304 scratch.grad_logits.zero();
3305 for idx in 0..vocab {
3306 let p = pdf[idx].clamp(1e-12, 1.0) as f32;
3307 let target = if idx == symbol as usize { 1.0 } else { 0.0 };
3308 let mut g = target - p;
3309 if clip > 0.0 {
3310 g = g.clamp(-clip, clip);
3311 }
3312 scratch.grad_logits[idx] = g;
3313 }
3314
3315 if scope.bias
3316 && let Some(bias) = out_bias
3317 {
3318 match optimizer {
3319 OptimizerKind::Sgd => {
3320 for idx in 0..bias.len().min(vocab) {
3321 bias[idx] += lr * scratch.grad_logits[idx];
3322 }
3323 }
3324 OptimizerKind::Adam => {
3325 let cfg = adam_step.as_ref().expect("adam cfg initialized");
3326 let Some(m) = out_bias_adam_m else {
3327 bail!("rwkv Adam output-bias state is missing (m)");
3328 };
3329 let Some(vv) = out_bias_adam_v else {
3330 bail!("rwkv Adam output-bias state is missing (v)");
3331 };
3332 let n = bias.len().min(vocab);
3333 apply_adam_vec_update_raw(
3334 &mut bias[0..n],
3335 &scratch.grad_logits.as_slice()[0..n],
3336 &mut m[0..n],
3337 &mut vv[0..n],
3338 cfg,
3339 );
3340 }
3341 }
3342 }
3343
3344 scratch.grad_x.zero();
3345 if scope.head {
3346 match optimizer {
3347 OptimizerKind::Sgd => {
3348 fused_sgd_head_backward_update(
3349 self.lm_head.as_mut_slice(),
3350 vocab,
3351 c,
3352 &scratch.grad_logits.as_slice()[0..vocab],
3353 scratch.x_normed.as_slice(),
3354 scratch.grad_x.as_mut_slice(),
3355 lr,
3356 clip,
3357 );
3358 }
3359 OptimizerKind::Adam => {
3360 let cfg = adam_step.as_ref().expect("adam cfg initialized");
3361 let adam = model_adam.as_mut().expect("adam state exists");
3362 fused_adam_head_backward_update(
3363 self.lm_head.as_mut_slice(),
3364 vocab,
3365 c,
3366 &scratch.grad_logits.as_slice()[0..vocab],
3367 scratch.x_normed.as_slice(),
3368 scratch.grad_x.as_mut_slice(),
3369 adam.lm_head.m.as_mut_slice(),
3370 adam.lm_head.v.as_mut_slice(),
3371 cfg,
3372 );
3373 }
3374 }
3375 } else {
3376 for row in 0..vocab {
3377 let g = scratch.grad_logits[row];
3378 if g == 0.0 {
3379 continue;
3380 }
3381 let row_off = row * c;
3382 for col in 0..c {
3383 scratch.grad_x[col] += self.lm_head[row_off + col] * g;
3384 }
3385 }
3386 }
3387
3388 let needs_backprop = scope.trains_non_head_params() || scope.head;
3389 if !needs_backprop {
3390 return Ok(());
3391 }
3392 layer_norm_backward(
3393 scratch.x.as_slice(),
3394 self.ln_out_w.as_slice(),
3395 scratch.grad_x.as_slice(),
3396 self.cfg.layer_norm_eps,
3397 scratch.grad_x2.as_mut_slice(),
3398 scratch.grad_x3.as_mut_slice(),
3399 scratch.grad_x4.as_mut_slice(),
3400 );
3401 if scope.head {
3402 match optimizer {
3403 OptimizerKind::Sgd => {
3404 sgd_vec_update(
3405 self.ln_out_w.as_mut_slice(),
3406 scratch.grad_x3.as_slice(),
3407 lr,
3408 clip,
3409 );
3410 sgd_vec_update(
3411 self.ln_out_b.as_mut_slice(),
3412 scratch.grad_x4.as_slice(),
3413 lr,
3414 clip,
3415 );
3416 }
3417 OptimizerKind::Adam => {
3418 let cfg = adam_step.as_ref().expect("adam cfg initialized");
3419 let adam = model_adam.as_mut().expect("adam state exists");
3420 apply_adam_vec_update(
3421 self.ln_out_w.as_mut_slice(),
3422 scratch.grad_x3.as_slice(),
3423 &mut adam.ln_out_w,
3424 cfg,
3425 );
3426 apply_adam_vec_update(
3427 self.ln_out_b.as_mut_slice(),
3428 scratch.grad_x4.as_slice(),
3429 &mut adam.ln_out_b,
3430 cfg,
3431 );
3432 }
3433 }
3434 }
3435 scratch.grad_x.copy_from_slice(scratch.grad_x2.as_slice());
3436 scratch.grad_v_first.zero();
3437
3438 for layer_idx in (0..self.cfg.num_layers).rev() {
3439 let tr = &scratch.train_trace_layers[layer_idx];
3440 let block = &mut self.blocks[layer_idx];
3441
3442 scratch.grad_x2.copy_from_slice(scratch.grad_x.as_slice()); scratch.grad_x3.copy_from_slice(scratch.grad_x.as_slice()); unsafe {
3448 kernel::gemv_t_avx(
3449 block.ffn.value_w.as_ptr(),
3450 scratch.grad_x3.as_ptr(),
3451 scratch.grad_ffn.as_mut_ptr(),
3452 c,
3453 i,
3454 );
3455 }
3456 if scope.ffn {
3457 match optimizer {
3458 OptimizerKind::Sgd => sgd_outer_update(
3459 block.ffn.value_w.as_mut_slice(),
3460 c,
3461 i,
3462 scratch.grad_x3.as_slice(),
3463 tr.ffn_k.as_slice(),
3464 lr,
3465 clip,
3466 ),
3467 OptimizerKind::Adam => {
3468 let cfg = adam_step.as_ref().expect("adam cfg initialized");
3469 let adam =
3470 &mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
3471 apply_adam_outer_update(
3472 block.ffn.value_w.as_mut_slice(),
3473 c,
3474 i,
3475 scratch.grad_x3.as_slice(),
3476 tr.ffn_k.as_slice(),
3477 &mut adam.ffn.value_w,
3478 cfg,
3479 );
3480 }
3481 }
3482 }
3483
3484 for col in 0..i {
3486 let pre = tr.ffn_pre[col];
3487 scratch.grad_ffn2[col] = if pre > 0.0 {
3488 scratch.grad_ffn[col] * (2.0 * pre)
3489 } else {
3490 0.0
3491 };
3492 }
3493
3494 unsafe {
3496 kernel::gemv_t_avx(
3497 block.ffn.key_w.as_ptr(),
3498 scratch.grad_ffn2.as_ptr(),
3499 scratch.grad_x4.as_mut_ptr(),
3500 i,
3501 c,
3502 );
3503 }
3504 if scope.ffn {
3505 match optimizer {
3506 OptimizerKind::Sgd => sgd_outer_update(
3507 block.ffn.key_w.as_mut_slice(),
3508 i,
3509 c,
3510 scratch.grad_ffn2.as_slice(),
3511 tr.ffn_xk.as_slice(),
3512 lr,
3513 clip,
3514 ),
3515 OptimizerKind::Adam => {
3516 let cfg = adam_step.as_ref().expect("adam cfg initialized");
3517 let adam =
3518 &mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
3519 apply_adam_outer_update(
3520 block.ffn.key_w.as_mut_slice(),
3521 i,
3522 c,
3523 scratch.grad_ffn2.as_slice(),
3524 tr.ffn_xk.as_slice(),
3525 &mut adam.ffn.key_w,
3526 cfg,
3527 );
3528 }
3529 }
3530 }
3531
3532 for col in 0..c {
3534 let g = scratch.grad_x4[col];
3535 let mix = block.ffn.x_k[col];
3536 let base = tr.ffn_norm[col];
3537 let prev = tr.ffn_x_prev_old[col];
3538 scratch.grad_x5[col] = g * (1.0 - mix); scratch.grad_param[col] = g * (prev - base); }
3541 if scope.ffn {
3542 match optimizer {
3543 OptimizerKind::Sgd => sgd_vec_update(
3544 block.ffn.x_k.as_mut_slice(),
3545 scratch.grad_param.as_slice(),
3546 lr,
3547 clip,
3548 ),
3549 OptimizerKind::Adam => {
3550 let cfg = adam_step.as_ref().expect("adam cfg initialized");
3551 let adam =
3552 &mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
3553 apply_adam_vec_update(
3554 block.ffn.x_k.as_mut_slice(),
3555 scratch.grad_param.as_slice(),
3556 &mut adam.ffn.x_k,
3557 cfg,
3558 );
3559 }
3560 }
3561 }
3562
3563 layer_norm_backward(
3565 tr.x_after_attn.as_slice(),
3566 block.ffn_norm_w.as_slice(),
3567 scratch.grad_x5.as_slice(),
3568 self.cfg.layer_norm_eps,
3569 scratch.grad_x4.as_mut_slice(),
3570 scratch.grad_x3.as_mut_slice(),
3571 scratch.grad_x6.as_mut_slice(),
3572 );
3573 if scope.ffn_norm {
3574 match optimizer {
3575 OptimizerKind::Sgd => {
3576 sgd_vec_update(
3577 block.ffn_norm_w.as_mut_slice(),
3578 scratch.grad_x3.as_slice(),
3579 lr,
3580 clip,
3581 );
3582 sgd_vec_update(
3583 block.ffn_norm_b.as_mut_slice(),
3584 scratch.grad_x6.as_slice(),
3585 lr,
3586 clip,
3587 );
3588 }
3589 OptimizerKind::Adam => {
3590 let cfg = adam_step.as_ref().expect("adam cfg initialized");
3591 let adam =
3592 &mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
3593 apply_adam_vec_update(
3594 block.ffn_norm_w.as_mut_slice(),
3595 scratch.grad_x3.as_slice(),
3596 &mut adam.ffn_norm_w,
3597 cfg,
3598 );
3599 apply_adam_vec_update(
3600 block.ffn_norm_b.as_mut_slice(),
3601 scratch.grad_x6.as_slice(),
3602 &mut adam.ffn_norm_b,
3603 cfg,
3604 );
3605 }
3606 }
3607 }
3608 for col in 0..c {
3609 scratch.grad_x2[col] += scratch.grad_x4[col];
3610 }
3611
3612 scratch.grad_x.copy_from_slice(scratch.grad_x2.as_slice()); scratch.grad_x3.copy_from_slice(scratch.grad_x2.as_slice()); unsafe {
3618 kernel::gemv_t_avx(
3619 block.attn.o_proj.as_ptr(),
3620 scratch.grad_x3.as_ptr(),
3621 scratch.grad_x4.as_mut_ptr(),
3622 c,
3623 c,
3624 );
3625 }
3626 if scope.attn {
3627 match optimizer {
3628 OptimizerKind::Sgd => sgd_outer_update(
3629 block.attn.o_proj.as_mut_slice(),
3630 c,
3631 c,
3632 scratch.grad_x3.as_slice(),
3633 tr.y_gate.as_slice(),
3634 lr,
3635 clip,
3636 ),
3637 OptimizerKind::Adam => {
3638 let cfg = adam_step.as_ref().expect("adam cfg initialized");
3639 let adam =
3640 &mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
3641 apply_adam_outer_update(
3642 block.attn.o_proj.as_mut_slice(),
3643 c,
3644 c,
3645 scratch.grad_x3.as_slice(),
3646 tr.y_gate.as_slice(),
3647 &mut adam.attn.o_proj,
3648 cfg,
3649 );
3650 }
3651 }
3652 }
3653
3654 for col in 0..c {
3656 let gy = scratch.grad_x4[col];
3657 scratch.grad_saved[col] = gy * tr.y_head[col]; scratch.grad_x4[col] = gy * tr.g[col]; }
3660
3661 scratch.grad_x2.zero(); scratch.grad_x3.zero(); scratch.grad_x6.zero(); scratch.grad_param.zero(); for head_idx in 0..h {
3667 let off = head_idx * n;
3668 let mut g_alpha = 0.0f32;
3669 for j in 0..n {
3670 let g = scratch.grad_x4[off + j];
3671 g_alpha += g * tr.v[off + j];
3672 scratch.grad_x6[off + j] += g * tr.alpha[head_idx];
3673 }
3674 for j in 0..n {
3675 let idx = off + j;
3676 let rk = block.attn.r_k[idx];
3677 let rv = tr.r[idx];
3678 let kv = tr.k[idx];
3679 let g = g_alpha * rk;
3680 scratch.grad_x2[idx] += g * kv;
3681 scratch.grad_x3[idx] += g * rv;
3682 scratch.grad_param[idx] += g_alpha * rv * kv;
3683 }
3684 }
3685 if scope.attn {
3686 match optimizer {
3687 OptimizerKind::Sgd => sgd_vec_update(
3688 block.attn.r_k.as_mut_slice(),
3689 scratch.grad_param.as_slice(),
3690 lr,
3691 clip,
3692 ),
3693 OptimizerKind::Adam => {
3694 let cfg = adam_step.as_ref().expect("adam cfg initialized");
3695 let adam =
3696 &mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
3697 apply_adam_vec_update(
3698 block.attn.r_k.as_mut_slice(),
3699 scratch.grad_param.as_slice(),
3700 &mut adam.attn.r_k,
3701 cfg,
3702 );
3703 }
3704 }
3705 }
3706
3707 scratch.grad_x5.as_mut_slice()[0..c].copy_from_slice(&scratch.grad_x4.as_slice()[0..c]);
3709 group_norm_backward(
3710 tr.y_wkv.as_slice(),
3711 block.attn.g_norm_w.as_slice(),
3712 scratch.grad_x5.as_slice(),
3713 h,
3714 n,
3715 self.cfg.group_norm_eps,
3716 scratch.grad_x4.as_mut_slice(), scratch.grad_param.as_mut_slice(), scratch.grad_param2.as_mut_slice(), );
3720 if scope.attn {
3721 match optimizer {
3722 OptimizerKind::Sgd => {
3723 sgd_vec_update(
3724 block.attn.g_norm_w.as_mut_slice(),
3725 scratch.grad_param.as_slice(),
3726 lr,
3727 clip,
3728 );
3729 sgd_vec_update(
3730 block.attn.g_norm_b.as_mut_slice(),
3731 scratch.grad_param2.as_slice(),
3732 lr,
3733 clip,
3734 );
3735 }
3736 OptimizerKind::Adam => {
3737 let cfg = adam_step.as_ref().expect("adam cfg initialized");
3738 let adam =
3739 &mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
3740 apply_adam_vec_update(
3741 block.attn.g_norm_w.as_mut_slice(),
3742 scratch.grad_param.as_slice(),
3743 &mut adam.attn.g_norm_w,
3744 cfg,
3745 );
3746 apply_adam_vec_update(
3747 block.attn.g_norm_b.as_mut_slice(),
3748 scratch.grad_param2.as_slice(),
3749 &mut adam.attn.g_norm_b,
3750 cfg,
3751 );
3752 }
3753 }
3754 }
3755
3756 scratch.grad_param.zero(); scratch.grad_x5.zero(); scratch.grad_param2.zero(); let s_old = tr.att_state_old.as_slice();
3761 let s_new = state.layers[layer_idx].att_state.as_slice();
3762 for head_idx in 0..h {
3763 let off = head_idx * n;
3764 let s_head_old_off = head_idx * n * n;
3765 let s_head_new_off = head_idx * n * n;
3766 let grad_y = &scratch.grad_x4.as_slice()[off..off + n];
3767 let r_head = &tr.r.as_slice()[off..off + n];
3768 let k_head = &tr.k.as_slice()[off..off + n];
3769 let kk_head = &tr.kk.as_slice()[off..off + n];
3770 let a_head = &tr.a.as_slice()[off..off + n];
3771 let v_head = &tr.v.as_slice()[off..off + n];
3772
3773 unsafe {
3774 kernel::gemv_t_avx(
3775 s_new.as_ptr().add(s_head_new_off),
3776 grad_y.as_ptr(),
3777 scratch.grad_low_rank.as_mut_ptr(),
3778 n,
3779 n,
3780 );
3781 kernel::gemv_t_avx(
3782 s_old.as_ptr().add(s_head_old_off),
3783 grad_y.as_ptr(),
3784 scratch.grad_low_rank2.as_mut_ptr(),
3785 n,
3786 n,
3787 );
3788 }
3789
3790 for j in 0..n {
3791 let idx = off + j;
3792 scratch.grad_x2[idx] += scratch.grad_low_rank[j];
3793 scratch.grad_param[idx] += r_head[j] * scratch.grad_low_rank2[j];
3794 }
3795
3796 unsafe {
3797 kernel::gemv_avx(
3798 s_old.as_ptr().add(s_head_old_off),
3799 kk_head.as_ptr(),
3800 scratch.grad_low_rank.as_mut_ptr(),
3801 n,
3802 n,
3803 );
3804 }
3805
3806 let mut dot_gv = 0.0f32;
3807 let mut dot_rk = 0.0f32;
3808 let mut dot_r_kka = 0.0f32;
3809 let mut sum_gy_u = 0.0f32;
3810 for j in 0..n {
3811 dot_gv += grad_y[j] * v_head[j];
3812 dot_rk += r_head[j] * k_head[j];
3813 dot_r_kka += r_head[j] * kk_head[j] * a_head[j];
3814 sum_gy_u += grad_y[j] * scratch.grad_low_rank[j];
3815 }
3816
3817 for j in 0..n {
3818 let idx = off + j;
3819 scratch.grad_x3[idx] += r_head[j] * dot_gv;
3820 scratch.grad_x6[idx] += grad_y[j] * dot_rk;
3821 scratch.grad_x5[idx] -= sum_gy_u * r_head[j] * kk_head[j];
3822 scratch.grad_low_rank[j] = -grad_y[j] * dot_r_kka;
3823 }
3824
3825 unsafe {
3826 kernel::gemv_t_avx(
3827 s_old.as_ptr().add(s_head_old_off),
3828 scratch.grad_low_rank.as_ptr(),
3829 scratch.grad_low_rank2.as_mut_ptr(),
3830 n,
3831 n,
3832 );
3833 }
3834 for j in 0..n {
3835 let idx = off + j;
3836 scratch.grad_param2[idx] +=
3837 scratch.grad_low_rank2[j] - sum_gy_u * r_head[j] * a_head[j];
3838 }
3839 }
3840
3841 for col in 0..c {
3843 let gk = scratch.grad_x3[col];
3844 let scale = 1.0 + (tr.a[col] - 1.0) * block.attn.k_a[col];
3845 let d_scale = gk * tr.k_pre[col];
3846 scratch.grad_x3[col] = gk * scale; scratch.grad_x5[col] += d_scale * block.attn.k_a[col]; scratch.grad_param[col] = d_scale * (tr.a[col] - 1.0); }
3850 for head_idx in 0..h {
3851 let off = head_idx * n;
3852 l2_normalize_backward(
3853 &tr.kk_pre.as_slice()[off..off + n],
3854 &tr.kk.as_slice()[off..off + n],
3855 &scratch.grad_param2.as_slice()[off..off + n],
3856 1e-12,
3857 &mut scratch.grad_x4.as_mut_slice()[off..off + n],
3858 );
3859 }
3860 for col in 0..c {
3861 let g = scratch.grad_x4[col];
3862 scratch.grad_x3[col] += g * block.attn.k_k[col]; scratch.grad_param2[col] = g * tr.k_pre[col]; }
3865 if scope.attn {
3866 match optimizer {
3867 OptimizerKind::Sgd => {
3868 sgd_vec_update(
3869 block.attn.k_a.as_mut_slice(),
3870 scratch.grad_param.as_slice(),
3871 lr,
3872 clip,
3873 );
3874 sgd_vec_update(
3875 block.attn.k_k.as_mut_slice(),
3876 scratch.grad_param2.as_slice(),
3877 lr,
3878 clip,
3879 );
3880 }
3881 OptimizerKind::Adam => {
3882 let cfg = adam_step.as_ref().expect("adam cfg initialized");
3883 let adam =
3884 &mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
3885 apply_adam_vec_update(
3886 block.attn.k_a.as_mut_slice(),
3887 scratch.grad_param.as_slice(),
3888 &mut adam.attn.k_a,
3889 cfg,
3890 );
3891 apply_adam_vec_update(
3892 block.attn.k_k.as_mut_slice(),
3893 scratch.grad_param2.as_slice(),
3894 &mut adam.attn.k_k,
3895 cfg,
3896 );
3897 }
3898 }
3899 }
3900
3901 scratch
3903 .grad_param2
3904 .copy_from_slice(scratch.grad_x6.as_slice()); if layer_idx == 0 {
3906 for col in 0..c {
3907 scratch.grad_x6[col] += scratch.grad_v_first[col];
3908 }
3909 } else if tr.uses_v_residual
3910 && let (Some(v1), Some(v2), Some(v0)) =
3911 (&mut block.attn.v1, &mut block.attn.v2, &mut block.attn.v0)
3912 {
3913 for col in 0..c {
3914 let gv = scratch.grad_param2[col];
3915 let nu = tr.nu[col];
3916 scratch.grad_x6[col] = gv * (1.0 - nu); scratch.grad_x3[col] = gv * (scratch.train_v_first[col] - tr.v_pre[col]); scratch.grad_v_first[col] += gv * nu; }
3920 for col in 0..c {
3921 let nu = tr.nu[col];
3922 scratch.grad_x3[col] *= nu * (1.0 - nu); }
3924 if scope.attn {
3925 match optimizer {
3926 OptimizerKind::Sgd => {
3927 sgd_vec_update(v0.as_mut_slice(), scratch.grad_x3.as_slice(), lr, clip)
3928 }
3929 OptimizerKind::Adam => {
3930 let cfg = adam_step.as_ref().expect("adam cfg initialized");
3931 let adam = &mut model_adam.as_mut().expect("adam state exists").blocks
3932 [layer_idx];
3933 apply_adam_vec_update(
3934 v0.as_mut_slice(),
3935 scratch.grad_x3.as_slice(),
3936 adam.attn.v0.as_mut().expect("adam v0 state"),
3937 cfg,
3938 );
3939 }
3940 }
3941 }
3942 if scope.attn {
3943 match optimizer {
3944 OptimizerKind::Sgd => sgd_outer_update(
3945 v2.as_mut_slice(),
3946 c,
3947 d_v,
3948 scratch.grad_x3.as_slice(),
3949 &tr.v_hidden.as_slice()[0..d_v],
3950 lr,
3951 clip,
3952 ),
3953 OptimizerKind::Adam => {
3954 let cfg = adam_step.as_ref().expect("adam cfg initialized");
3955 let adam = &mut model_adam.as_mut().expect("adam state exists").blocks
3956 [layer_idx];
3957 apply_adam_outer_update(
3958 v2.as_mut_slice(),
3959 c,
3960 d_v,
3961 scratch.grad_x3.as_slice(),
3962 &tr.v_hidden.as_slice()[0..d_v],
3963 adam.attn.v2.as_mut().expect("adam v2 state"),
3964 cfg,
3965 );
3966 }
3967 }
3968 }
3969 unsafe {
3970 kernel::gemv_t_avx(
3971 v2.as_ptr(),
3972 scratch.grad_x3.as_ptr(),
3973 scratch.grad_low_rank.as_mut_ptr(),
3974 c,
3975 d_v,
3976 );
3977 }
3978 if scope.attn {
3979 match optimizer {
3980 OptimizerKind::Sgd => sgd_outer_update(
3981 v1.as_mut_slice(),
3982 d_v,
3983 c,
3984 &scratch.grad_low_rank.as_slice()[0..d_v],
3985 tr.xv.as_slice(),
3986 lr,
3987 clip,
3988 ),
3989 OptimizerKind::Adam => {
3990 let cfg = adam_step.as_ref().expect("adam cfg initialized");
3991 let adam = &mut model_adam.as_mut().expect("adam state exists").blocks
3992 [layer_idx];
3993 apply_adam_outer_update(
3994 v1.as_mut_slice(),
3995 d_v,
3996 c,
3997 &scratch.grad_low_rank.as_slice()[0..d_v],
3998 tr.xv.as_slice(),
3999 adam.attn.v1.as_mut().expect("adam v1 state"),
4000 cfg,
4001 );
4002 }
4003 }
4004 }
4005 for col in 0..c {
4006 let mut acc = 0.0f32;
4007 for row in 0..d_v {
4008 acc += v1[row * c + col] * scratch.grad_low_rank[row];
4009 }
4010 scratch.grad_x4[col] += acc; }
4012 }
4013
4014 let proj_size = c * c;
4016 if scope.attn {
4017 match optimizer {
4018 OptimizerKind::Sgd => {
4019 sgd_outer_update(
4020 &mut block.attn.rkv_proj.as_mut_slice()[0..proj_size],
4021 c,
4022 c,
4023 scratch.grad_x2.as_slice(),
4024 tr.xr.as_slice(),
4025 lr,
4026 clip,
4027 );
4028 sgd_outer_update(
4029 &mut block.attn.rkv_proj.as_mut_slice()[proj_size..2 * proj_size],
4030 c,
4031 c,
4032 scratch.grad_x3.as_slice(),
4033 tr.xk.as_slice(),
4034 lr,
4035 clip,
4036 );
4037 sgd_outer_update(
4038 &mut block.attn.rkv_proj.as_mut_slice()[2 * proj_size..3 * proj_size],
4039 c,
4040 c,
4041 scratch.grad_x6.as_slice(),
4042 tr.xv.as_slice(),
4043 lr,
4044 clip,
4045 );
4046 }
4047 OptimizerKind::Adam => {
4048 let cfg = adam_step.as_ref().expect("adam cfg initialized");
4049 let adam =
4050 &mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
4051 apply_adam_outer_update_raw(
4052 &mut block.attn.rkv_proj.as_mut_slice()[0..proj_size],
4053 c,
4054 c,
4055 scratch.grad_x2.as_slice(),
4056 tr.xr.as_slice(),
4057 &mut adam.attn.rkv_proj.m.as_mut_slice()[0..proj_size],
4058 &mut adam.attn.rkv_proj.v.as_mut_slice()[0..proj_size],
4059 cfg,
4060 );
4061 apply_adam_outer_update_raw(
4062 &mut block.attn.rkv_proj.as_mut_slice()[proj_size..2 * proj_size],
4063 c,
4064 c,
4065 scratch.grad_x3.as_slice(),
4066 tr.xk.as_slice(),
4067 &mut adam.attn.rkv_proj.m.as_mut_slice()[proj_size..2 * proj_size],
4068 &mut adam.attn.rkv_proj.v.as_mut_slice()[proj_size..2 * proj_size],
4069 cfg,
4070 );
4071 apply_adam_outer_update_raw(
4072 &mut block.attn.rkv_proj.as_mut_slice()[2 * proj_size..3 * proj_size],
4073 c,
4074 c,
4075 scratch.grad_x6.as_slice(),
4076 tr.xv.as_slice(),
4077 &mut adam.attn.rkv_proj.m.as_mut_slice()[2 * proj_size..3 * proj_size],
4078 &mut adam.attn.rkv_proj.v.as_mut_slice()[2 * proj_size..3 * proj_size],
4079 cfg,
4080 );
4081 }
4082 }
4083 }
4084 let proj = block.attn.rkv_proj.as_slice();
4085 unsafe {
4086 kernel::gemv_t_avx(
4087 proj.as_ptr(),
4088 scratch.grad_x2.as_ptr(),
4089 scratch.grad_param.as_mut_ptr(),
4090 c,
4091 c,
4092 );
4093 kernel::gemv_t_avx(
4094 proj.as_ptr().add(proj_size),
4095 scratch.grad_x3.as_ptr(),
4096 scratch.grad_param2.as_mut_ptr(),
4097 c,
4098 c,
4099 );
4100 kernel::gemv_t_avx(
4101 proj.as_ptr().add(2 * proj_size),
4102 scratch.grad_x6.as_ptr(),
4103 scratch.grad_x4.as_mut_ptr(),
4104 c,
4105 c,
4106 );
4107 }
4108
4109 let inv_sqrt_e = 1.0 / std::f32::consts::E.sqrt();
4111 for col in 0..c {
4112 let sig = tr.w_sigmoid[col];
4113 let d_sig = scratch.grad_param[col] * (-inv_sqrt_e) * tr.w_decay[col];
4114 scratch.grad_param[col] = d_sig * sig * (1.0 - sig); }
4116 if scope.attn {
4117 match optimizer {
4118 OptimizerKind::Sgd => sgd_vec_update(
4119 block.attn.w0.as_mut_slice(),
4120 scratch.grad_param.as_slice(),
4121 lr,
4122 clip,
4123 ),
4124 OptimizerKind::Adam => {
4125 let cfg = adam_step.as_ref().expect("adam cfg initialized");
4126 let adam =
4127 &mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
4128 apply_adam_vec_update(
4129 block.attn.w0.as_mut_slice(),
4130 scratch.grad_param.as_slice(),
4131 &mut adam.attn.w0,
4132 cfg,
4133 );
4134 }
4135 }
4136 match optimizer {
4137 OptimizerKind::Sgd => sgd_outer_update(
4138 block.attn.w2.as_mut_slice(),
4139 c,
4140 d_w,
4141 scratch.grad_param.as_slice(),
4142 &tr.w_hidden.as_slice()[0..d_w],
4143 lr,
4144 clip,
4145 ),
4146 OptimizerKind::Adam => {
4147 let cfg = adam_step.as_ref().expect("adam cfg initialized");
4148 let adam =
4149 &mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
4150 apply_adam_outer_update(
4151 block.attn.w2.as_mut_slice(),
4152 c,
4153 d_w,
4154 scratch.grad_param.as_slice(),
4155 &tr.w_hidden.as_slice()[0..d_w],
4156 &mut adam.attn.w2,
4157 cfg,
4158 );
4159 }
4160 }
4161 }
4162 unsafe {
4163 kernel::gemv_t_avx(
4164 block.attn.w2.as_ptr(),
4165 scratch.grad_param.as_ptr(),
4166 scratch.grad_low_rank.as_mut_ptr(),
4167 c,
4168 d_w,
4169 );
4170 }
4171 for col in 0..d_w {
4172 let t = tr.w_hidden[col];
4173 scratch.grad_low_rank[col] *= 1.0 - t * t;
4174 }
4175 if scope.attn {
4176 match optimizer {
4177 OptimizerKind::Sgd => sgd_outer_update(
4178 block.attn.w1.as_mut_slice(),
4179 d_w,
4180 c,
4181 &scratch.grad_low_rank.as_slice()[0..d_w],
4182 tr.xw.as_slice(),
4183 lr,
4184 clip,
4185 ),
4186 OptimizerKind::Adam => {
4187 let cfg = adam_step.as_ref().expect("adam cfg initialized");
4188 let adam =
4189 &mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
4190 apply_adam_outer_update(
4191 block.attn.w1.as_mut_slice(),
4192 d_w,
4193 c,
4194 &scratch.grad_low_rank.as_slice()[0..d_w],
4195 tr.xw.as_slice(),
4196 &mut adam.attn.w1,
4197 cfg,
4198 );
4199 }
4200 }
4201 }
4202 unsafe {
4203 kernel::gemv_t_avx(
4204 block.attn.w1.as_ptr(),
4205 scratch.grad_low_rank.as_ptr(),
4206 scratch.grad_x6.as_mut_ptr(),
4207 d_w,
4208 c,
4209 );
4210 }
4211
4212 for col in 0..c {
4214 let a = tr.a[col];
4215 scratch.grad_x5[col] *= a * (1.0 - a); }
4217 if scope.attn {
4218 match optimizer {
4219 OptimizerKind::Sgd => sgd_vec_update(
4220 block.attn.a0.as_mut_slice(),
4221 scratch.grad_x5.as_slice(),
4222 lr,
4223 clip,
4224 ),
4225 OptimizerKind::Adam => {
4226 let cfg = adam_step.as_ref().expect("adam cfg initialized");
4227 let adam =
4228 &mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
4229 apply_adam_vec_update(
4230 block.attn.a0.as_mut_slice(),
4231 scratch.grad_x5.as_slice(),
4232 &mut adam.attn.a0,
4233 cfg,
4234 );
4235 }
4236 }
4237 match optimizer {
4238 OptimizerKind::Sgd => sgd_outer_update(
4239 block.attn.a2.as_mut_slice(),
4240 c,
4241 d_a,
4242 scratch.grad_x5.as_slice(),
4243 &tr.a_hidden.as_slice()[0..d_a],
4244 lr,
4245 clip,
4246 ),
4247 OptimizerKind::Adam => {
4248 let cfg = adam_step.as_ref().expect("adam cfg initialized");
4249 let adam =
4250 &mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
4251 apply_adam_outer_update(
4252 block.attn.a2.as_mut_slice(),
4253 c,
4254 d_a,
4255 scratch.grad_x5.as_slice(),
4256 &tr.a_hidden.as_slice()[0..d_a],
4257 &mut adam.attn.a2,
4258 cfg,
4259 );
4260 }
4261 }
4262 }
4263 unsafe {
4264 kernel::gemv_t_avx(
4265 block.attn.a2.as_ptr(),
4266 scratch.grad_x5.as_ptr(),
4267 scratch.grad_low_rank.as_mut_ptr(),
4268 c,
4269 d_a,
4270 );
4271 }
4272 if scope.attn {
4273 match optimizer {
4274 OptimizerKind::Sgd => sgd_outer_update(
4275 block.attn.a1.as_mut_slice(),
4276 d_a,
4277 c,
4278 &scratch.grad_low_rank.as_slice()[0..d_a],
4279 tr.xa.as_slice(),
4280 lr,
4281 clip,
4282 ),
4283 OptimizerKind::Adam => {
4284 let cfg = adam_step.as_ref().expect("adam cfg initialized");
4285 let adam =
4286 &mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
4287 apply_adam_outer_update(
4288 block.attn.a1.as_mut_slice(),
4289 d_a,
4290 c,
4291 &scratch.grad_low_rank.as_slice()[0..d_a],
4292 tr.xa.as_slice(),
4293 &mut adam.attn.a1,
4294 cfg,
4295 );
4296 }
4297 }
4298 }
4299 unsafe {
4300 kernel::gemv_t_avx(
4301 block.attn.a1.as_ptr(),
4302 scratch.grad_low_rank.as_ptr(),
4303 scratch.grad_x5.as_mut_ptr(),
4304 d_a,
4305 c,
4306 );
4307 }
4308
4309 if scope.attn {
4311 match optimizer {
4312 OptimizerKind::Sgd => sgd_outer_update(
4313 block.attn.g2.as_mut_slice(),
4314 c,
4315 d_g,
4316 scratch.grad_saved.as_slice(),
4317 &tr.g_hidden.as_slice()[0..d_g],
4318 lr,
4319 clip,
4320 ),
4321 OptimizerKind::Adam => {
4322 let cfg = adam_step.as_ref().expect("adam cfg initialized");
4323 let adam =
4324 &mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
4325 apply_adam_outer_update(
4326 block.attn.g2.as_mut_slice(),
4327 c,
4328 d_g,
4329 scratch.grad_saved.as_slice(),
4330 &tr.g_hidden.as_slice()[0..d_g],
4331 &mut adam.attn.g2,
4332 cfg,
4333 );
4334 }
4335 }
4336 }
4337 unsafe {
4338 kernel::gemv_t_avx(
4339 block.attn.g2.as_ptr(),
4340 scratch.grad_saved.as_ptr(),
4341 scratch.grad_low_rank.as_mut_ptr(),
4342 c,
4343 d_g,
4344 );
4345 }
4346 for col in 0..d_g {
4347 let sig = tr.g_hidden[col];
4348 scratch.grad_low_rank2[col] = scratch.grad_low_rank[col] * sig * (1.0 - sig);
4349 }
4350 if scope.attn {
4351 match optimizer {
4352 OptimizerKind::Sgd => sgd_outer_update(
4353 block.attn.g1.as_mut_slice(),
4354 d_g,
4355 c,
4356 &scratch.grad_low_rank2.as_slice()[0..d_g],
4357 tr.xg.as_slice(),
4358 lr,
4359 clip,
4360 ),
4361 OptimizerKind::Adam => {
4362 let cfg = adam_step.as_ref().expect("adam cfg initialized");
4363 let adam =
4364 &mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
4365 apply_adam_outer_update(
4366 block.attn.g1.as_mut_slice(),
4367 d_g,
4368 c,
4369 &scratch.grad_low_rank2.as_slice()[0..d_g],
4370 tr.xg.as_slice(),
4371 &mut adam.attn.g1,
4372 cfg,
4373 );
4374 }
4375 }
4376 }
4377 unsafe {
4378 kernel::gemv_t_avx(
4379 block.attn.g1.as_ptr(),
4380 scratch.grad_low_rank2.as_ptr(),
4381 scratch.grad_saved.as_mut_ptr(),
4382 d_g,
4383 c,
4384 );
4385 }
4386
4387 scratch.grad_x3.zero(); for col in 0..c {
4392 let g = scratch.grad_param[col];
4393 let mix = block.attn.x_r[col];
4394 let base = tr.attn_norm[col];
4395 let prev = tr.att_x_prev_old[col];
4396 scratch.grad_x3[col] += g * (1.0 - mix);
4397 scratch.grad_x2[col] = g * (prev - base);
4398 }
4399 if scope.attn {
4400 match optimizer {
4401 OptimizerKind::Sgd => sgd_vec_update(
4402 block.attn.x_r.as_mut_slice(),
4403 scratch.grad_x2.as_slice(),
4404 lr,
4405 clip,
4406 ),
4407 OptimizerKind::Adam => {
4408 let cfg = adam_step.as_ref().expect("adam cfg initialized");
4409 let adam =
4410 &mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
4411 apply_adam_vec_update(
4412 block.attn.x_r.as_mut_slice(),
4413 scratch.grad_x2.as_slice(),
4414 &mut adam.attn.x_r,
4415 cfg,
4416 );
4417 }
4418 }
4419 }
4420
4421 for col in 0..c {
4423 let g = scratch.grad_x6[col];
4424 let mix = block.attn.x_w[col];
4425 let base = tr.attn_norm[col];
4426 let prev = tr.att_x_prev_old[col];
4427 scratch.grad_x3[col] += g * (1.0 - mix);
4428 scratch.grad_x2[col] = g * (prev - base);
4429 }
4430 if scope.attn {
4431 match optimizer {
4432 OptimizerKind::Sgd => sgd_vec_update(
4433 block.attn.x_w.as_mut_slice(),
4434 scratch.grad_x2.as_slice(),
4435 lr,
4436 clip,
4437 ),
4438 OptimizerKind::Adam => {
4439 let cfg = adam_step.as_ref().expect("adam cfg initialized");
4440 let adam =
4441 &mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
4442 apply_adam_vec_update(
4443 block.attn.x_w.as_mut_slice(),
4444 scratch.grad_x2.as_slice(),
4445 &mut adam.attn.x_w,
4446 cfg,
4447 );
4448 }
4449 }
4450 }
4451
4452 for col in 0..c {
4454 let g = scratch.grad_param2[col];
4455 let mix = block.attn.x_k[col];
4456 let base = tr.attn_norm[col];
4457 let prev = tr.att_x_prev_old[col];
4458 scratch.grad_x3[col] += g * (1.0 - mix);
4459 scratch.grad_x2[col] = g * (prev - base);
4460 }
4461 if scope.attn {
4462 match optimizer {
4463 OptimizerKind::Sgd => sgd_vec_update(
4464 block.attn.x_k.as_mut_slice(),
4465 scratch.grad_x2.as_slice(),
4466 lr,
4467 clip,
4468 ),
4469 OptimizerKind::Adam => {
4470 let cfg = adam_step.as_ref().expect("adam cfg initialized");
4471 let adam =
4472 &mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
4473 apply_adam_vec_update(
4474 block.attn.x_k.as_mut_slice(),
4475 scratch.grad_x2.as_slice(),
4476 &mut adam.attn.x_k,
4477 cfg,
4478 );
4479 }
4480 }
4481 }
4482
4483 for col in 0..c {
4485 let g = scratch.grad_x4[col];
4486 let mix = block.attn.x_v[col];
4487 let base = tr.attn_norm[col];
4488 let prev = tr.att_x_prev_old[col];
4489 scratch.grad_x3[col] += g * (1.0 - mix);
4490 scratch.grad_x2[col] = g * (prev - base);
4491 }
4492 if scope.attn {
4493 match optimizer {
4494 OptimizerKind::Sgd => sgd_vec_update(
4495 block.attn.x_v.as_mut_slice(),
4496 scratch.grad_x2.as_slice(),
4497 lr,
4498 clip,
4499 ),
4500 OptimizerKind::Adam => {
4501 let cfg = adam_step.as_ref().expect("adam cfg initialized");
4502 let adam =
4503 &mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
4504 apply_adam_vec_update(
4505 block.attn.x_v.as_mut_slice(),
4506 scratch.grad_x2.as_slice(),
4507 &mut adam.attn.x_v,
4508 cfg,
4509 );
4510 }
4511 }
4512 }
4513
4514 for col in 0..c {
4516 let g = scratch.grad_x5[col];
4517 let mix = block.attn.x_a[col];
4518 let base = tr.attn_norm[col];
4519 let prev = tr.att_x_prev_old[col];
4520 scratch.grad_x3[col] += g * (1.0 - mix);
4521 scratch.grad_x2[col] = g * (prev - base);
4522 }
4523 if scope.attn {
4524 match optimizer {
4525 OptimizerKind::Sgd => sgd_vec_update(
4526 block.attn.x_a.as_mut_slice(),
4527 scratch.grad_x2.as_slice(),
4528 lr,
4529 clip,
4530 ),
4531 OptimizerKind::Adam => {
4532 let cfg = adam_step.as_ref().expect("adam cfg initialized");
4533 let adam =
4534 &mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
4535 apply_adam_vec_update(
4536 block.attn.x_a.as_mut_slice(),
4537 scratch.grad_x2.as_slice(),
4538 &mut adam.attn.x_a,
4539 cfg,
4540 );
4541 }
4542 }
4543 }
4544
4545 for col in 0..c {
4547 let g = scratch.grad_saved[col];
4548 let mix = block.attn.x_g[col];
4549 let base = tr.attn_norm[col];
4550 let prev = tr.att_x_prev_old[col];
4551 scratch.grad_x3[col] += g * (1.0 - mix);
4552 scratch.grad_x2[col] = g * (prev - base);
4553 }
4554 if scope.attn {
4555 match optimizer {
4556 OptimizerKind::Sgd => sgd_vec_update(
4557 block.attn.x_g.as_mut_slice(),
4558 scratch.grad_x2.as_slice(),
4559 lr,
4560 clip,
4561 ),
4562 OptimizerKind::Adam => {
4563 let cfg = adam_step.as_ref().expect("adam cfg initialized");
4564 let adam =
4565 &mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
4566 apply_adam_vec_update(
4567 block.attn.x_g.as_mut_slice(),
4568 scratch.grad_x2.as_slice(),
4569 &mut adam.attn.x_g,
4570 cfg,
4571 );
4572 }
4573 }
4574 }
4575
4576 layer_norm_backward(
4578 tr.x_after_pre.as_slice(),
4579 block.attn_norm_w.as_slice(),
4580 scratch.grad_x3.as_slice(),
4581 self.cfg.layer_norm_eps,
4582 scratch.grad_x2.as_mut_slice(),
4583 scratch.grad_x4.as_mut_slice(),
4584 scratch.grad_x5.as_mut_slice(),
4585 );
4586 if scope.attn_norm {
4587 match optimizer {
4588 OptimizerKind::Sgd => {
4589 sgd_vec_update(
4590 block.attn_norm_w.as_mut_slice(),
4591 scratch.grad_x4.as_slice(),
4592 lr,
4593 clip,
4594 );
4595 sgd_vec_update(
4596 block.attn_norm_b.as_mut_slice(),
4597 scratch.grad_x5.as_slice(),
4598 lr,
4599 clip,
4600 );
4601 }
4602 OptimizerKind::Adam => {
4603 let cfg = adam_step.as_ref().expect("adam cfg initialized");
4604 let adam =
4605 &mut model_adam.as_mut().expect("adam state exists").blocks[layer_idx];
4606 apply_adam_vec_update(
4607 block.attn_norm_w.as_mut_slice(),
4608 scratch.grad_x4.as_slice(),
4609 &mut adam.attn_norm_w,
4610 cfg,
4611 );
4612 apply_adam_vec_update(
4613 block.attn_norm_b.as_mut_slice(),
4614 scratch.grad_x5.as_slice(),
4615 &mut adam.attn_norm_b,
4616 cfg,
4617 );
4618 }
4619 }
4620 }
4621 for col in 0..c {
4622 scratch.grad_x[col] += scratch.grad_x2[col];
4623 }
4624
4625 if layer_idx == 0
4627 && let (Some(w), Some(b)) = (&mut block.pre_norm_w, &mut block.pre_norm_b)
4628 {
4629 layer_norm_backward(
4630 tr.x_in.as_slice(),
4631 w.as_slice(),
4632 scratch.grad_x.as_slice(),
4633 self.cfg.layer_norm_eps,
4634 scratch.grad_x2.as_mut_slice(),
4635 scratch.grad_x3.as_mut_slice(),
4636 scratch.grad_x4.as_mut_slice(),
4637 );
4638 if scope.pre_norm {
4639 match optimizer {
4640 OptimizerKind::Sgd => {
4641 sgd_vec_update(w.as_mut_slice(), scratch.grad_x3.as_slice(), lr, clip);
4642 sgd_vec_update(b.as_mut_slice(), scratch.grad_x4.as_slice(), lr, clip);
4643 }
4644 OptimizerKind::Adam => {
4645 let cfg = adam_step.as_ref().expect("adam cfg initialized");
4646 let adam = &mut model_adam.as_mut().expect("adam state exists").blocks
4647 [layer_idx];
4648 apply_adam_vec_update(
4649 w.as_mut_slice(),
4650 scratch.grad_x3.as_slice(),
4651 adam.pre_norm_w.as_mut().expect("adam pre_norm_w"),
4652 cfg,
4653 );
4654 apply_adam_vec_update(
4655 b.as_mut_slice(),
4656 scratch.grad_x4.as_slice(),
4657 adam.pre_norm_b.as_mut().expect("adam pre_norm_b"),
4658 cfg,
4659 );
4660 }
4661 }
4662 }
4663 scratch.grad_x.copy_from_slice(scratch.grad_x2.as_slice());
4664 }
4665 }
4666
4667 if scope.embed {
4668 let token_idx = scratch
4669 .train_token
4670 .min(self.cfg.vocab_size.saturating_sub(1));
4671 let off = token_idx * c;
4672 let row = &mut self.embeddings.as_mut_slice()[off..off + c];
4673 match optimizer {
4674 OptimizerKind::Sgd => {
4675 sgd_vec_update(row, scratch.grad_x.as_slice(), lr, clip);
4676 }
4677 OptimizerKind::Adam => {
4678 let cfg = adam_step.as_ref().expect("adam cfg initialized");
4679 let adam = model_adam.as_mut().expect("adam state exists");
4680 let m = &mut adam.embeddings.m.as_mut_slice()[off..off + c];
4681 let v = &mut adam.embeddings.v.as_mut_slice()[off..off + c];
4682 apply_adam_vec_update_raw(row, scratch.grad_x.as_slice(), m, v, cfg);
4683 }
4684 }
4685 }
4686 Ok(())
4687 }
4688
4689 #[inline(never)]
4692 pub fn forward<'a>(
4693 &'a self,
4694 scratch: &'a mut ScratchBuffers,
4695 token: u32,
4696 state: &mut State,
4697 ) -> &'a [f32] {
4698 let mut sink = NullProfiler;
4699 self.forward_with_sink(scratch, token, state, &mut sink)
4700 }
4701
4702 #[inline(never)]
4704 pub fn forward_with_profiler<'a, S: ProfilerSink>(
4705 &'a self,
4706 scratch: &'a mut ScratchBuffers,
4707 token: u32,
4708 state: &mut State,
4709 profiler: &mut S,
4710 ) -> &'a [f32] {
4711 self.forward_with_sink(scratch, token, state, profiler)
4712 }
4713
4714 #[inline(never)]
4715 fn forward_with_sink<'a, S: ProfilerSink>(
4716 &'a self,
4717 scratch: &'a mut ScratchBuffers,
4718 token: u32,
4719 state: &mut State,
4720 profiler: &mut S,
4721 ) -> &'a [f32] {
4722 if scratch.capture_train_trace {
4723 self.forward_with_sink_impl::<true, S>(scratch, token, state, profiler)
4724 } else {
4725 self.forward_with_sink_impl::<false, S>(scratch, token, state, profiler)
4726 }
4727 }
4728
    /// Monomorphized single-token forward pass over all layers.
    ///
    /// When `CAPTURE` is true, intermediate activations are copied into
    /// `scratch.train_trace_layers[..]` after each stage so a later backward
    /// pass can replay them; when false those stores are compiled out.
    /// `S::ENABLED` likewise gates the `Instant`-based per-layer timing calls.
    /// Returns the logits slice (`vocab_size` entries) living in `scratch`.
    fn forward_with_sink_impl<'a, const CAPTURE: bool, S: ProfilerSink>(
        &'a self,
        scratch: &'a mut ScratchBuffers,
        token: u32,
        state: &mut State,
        profiler: &mut S,
    ) -> &'a [f32] {
        let c = self.cfg.hidden_size;
        let _h = self.cfg.num_heads;
        let _n = self.cfg.head_dim;
        let num_layers = self.cfg.num_layers;
        // Clamp an out-of-range token id to the last vocab row instead of panicking.
        let token_idx = (token as usize).min(self.cfg.vocab_size.saturating_sub(1));

        // x <- embedding row for this token.
        let emb_offset = token_idx * c;
        let emb_slice = &self.embeddings.as_slice()[emb_offset..emb_offset + c];
        scratch.x.as_mut_slice().copy_from_slice(emb_slice);
        if CAPTURE {
            // Remember which embedding row the backward pass should update.
            scratch.train_token = token_idx;
            scratch.train_trace_valid = true;
        } else {
            scratch.train_trace_valid = false;
        }

        profiler.begin_token();

        // SAFETY-relevant note: the kernel calls below take raw pointers into
        // `scratch` / weight buffers; several (e.g. layer_norm_avx on x) read
        // and write the same buffer, which the kernels support in place.
        unsafe {
            for layer_idx in 0..num_layers {
                if CAPTURE {
                    scratch.train_trace_layers[layer_idx]
                        .x_in
                        .copy_from(&scratch.x);
                }
                // Optional pre-norm (present only on some blocks; see
                // `pre_norm_w`/`pre_norm_b` being Option), applied in place.
                if let (Some(w), Some(b)) = (
                    &self.blocks[layer_idx].pre_norm_w,
                    &self.blocks[layer_idx].pre_norm_b,
                ) {
                    kernel::layer_norm_avx(
                        scratch.x.as_ptr(),
                        w.as_ptr(),
                        b.as_ptr(),
                        scratch.x.as_mut_ptr(),
                        c,
                        self.cfg.layer_norm_eps,
                    );
                }
                if CAPTURE {
                    scratch.train_trace_layers[layer_idx]
                        .x_after_pre
                        .copy_from(&scratch.x);
                }

                // Attention-branch layer norm: x_normed <- LN(x).
                kernel::layer_norm_avx(
                    scratch.x.as_ptr(),
                    self.blocks[layer_idx].attn_norm_w.as_ptr(),
                    self.blocks[layer_idx].attn_norm_b.as_ptr(),
                    scratch.x_normed.as_mut_ptr(),
                    c,
                    self.cfg.layer_norm_eps,
                );
                if CAPTURE {
                    scratch.train_trace_layers[layer_idx]
                        .attn_norm
                        .copy_from(&scratch.x_normed);
                }

                // Raw trace pointer handed to the attention/FFN kernels; null
                // when not capturing (callees must check CAPTURE before deref).
                let trace_ptr = if CAPTURE {
                    &mut scratch.train_trace_layers[layer_idx] as *mut LayerTrainTrace
                } else {
                    std::ptr::null_mut()
                };
                if S::ENABLED {
                    let attn_start = Instant::now();
                    self.attention_forward_impl::<CAPTURE>(scratch, layer_idx, state, trace_ptr);
                    profiler.record_attention(layer_idx, attn_start.elapsed());
                } else {
                    self.attention_forward_impl::<CAPTURE>(scratch, layer_idx, state, trace_ptr);
                }

                // Residual add: x += attention output.
                kernel::add_avx(
                    scratch.x.as_ptr(),
                    scratch.att_out.as_ptr(),
                    scratch.x.as_mut_ptr(),
                    c,
                );
                if CAPTURE {
                    scratch.train_trace_layers[layer_idx]
                        .x_after_attn
                        .copy_from(&scratch.x);
                }

                // FFN-branch layer norm: x_normed <- LN(x).
                kernel::layer_norm_avx(
                    scratch.x.as_ptr(),
                    self.blocks[layer_idx].ffn_norm_w.as_ptr(),
                    self.blocks[layer_idx].ffn_norm_b.as_ptr(),
                    scratch.x_normed.as_mut_ptr(),
                    c,
                    self.cfg.layer_norm_eps,
                );
                if CAPTURE {
                    scratch.train_trace_layers[layer_idx]
                        .ffn_norm
                        .copy_from(&scratch.x_normed);
                }

                if S::ENABLED {
                    let ffn_start = Instant::now();
                    self.ffn_forward_impl::<CAPTURE>(
                        scratch,
                        layer_idx,
                        &mut state.layers[layer_idx],
                        trace_ptr,
                    );
                    profiler.record_ffn(layer_idx, ffn_start.elapsed());
                } else {
                    self.ffn_forward_impl::<CAPTURE>(
                        scratch,
                        layer_idx,
                        &mut state.layers[layer_idx],
                        trace_ptr,
                    );
                }

                // Residual add: x += FFN output.
                kernel::add_avx(
                    scratch.x.as_ptr(),
                    scratch.ffn_out.as_ptr(),
                    scratch.x.as_mut_ptr(),
                    c,
                );
                if CAPTURE {
                    scratch.train_trace_layers[layer_idx]
                        .x_out
                        .copy_from(&scratch.x);
                }
            }

            // Final layer norm before the output head.
            kernel::layer_norm_avx(
                scratch.x.as_ptr(),
                self.ln_out_w.as_ptr(),
                self.ln_out_b.as_ptr(),
                scratch.x_normed.as_mut_ptr(),
                c,
                self.cfg.layer_norm_eps,
            );

            // LM head: logits <- lm_head (vocab_size x c) * x_normed.
            kernel::gemv_avx(
                self.lm_head.as_ptr(),
                scratch.x_normed.as_ptr(),
                scratch.logits.as_mut_ptr(),
                self.cfg.vocab_size,
                c,
            );
        }
        if CAPTURE {
            // Snapshot v_first (set by layer 0's attention) for the backward pass.
            scratch.train_v_first.copy_from(&state.v_first);
        }

        scratch.logits.as_slice()
    }
4896
4897 #[inline(always)]
4898 unsafe fn attention_forward_impl<const CAPTURE: bool>(
4899 &self,
4900 scratch: &mut ScratchBuffers,
4901 layer_idx: usize,
4902 state: &mut State,
4903 trace: *mut LayerTrainTrace,
4904 ) {
4905 let attn = &self.blocks[layer_idx].attn;
4906 let layer_state = &mut state.layers[layer_idx];
4907 let c = self.cfg.hidden_size;
4908 let h = self.cfg.num_heads;
4909 let n = self.cfg.head_dim;
4910 let d_w = self.cfg.decay_low_rank;
4911 let d_a = self.cfg.a_low_rank;
4912 let d_g = self.cfg.g_low_rank;
4913 if CAPTURE {
4914 let tr = &mut *trace;
4915 tr.att_x_prev_old.copy_from(&layer_state.att_x_prev);
4916 tr.att_state_old.copy_from(&layer_state.att_state);
4917 }
4918
4919 kernel::token_shift_multi6_avx(
4920 scratch.x_normed.as_ptr(),
4921 layer_state.att_x_prev.as_ptr(),
4922 attn.x_r.as_ptr(),
4923 attn.x_w.as_ptr(),
4924 attn.x_k.as_ptr(),
4925 attn.x_v.as_ptr(),
4926 attn.x_a.as_ptr(),
4927 attn.x_g.as_ptr(),
4928 scratch.xr.as_mut_ptr(),
4929 scratch.xw.as_mut_ptr(),
4930 scratch.xk.as_mut_ptr(),
4931 scratch.xv.as_mut_ptr(),
4932 scratch.xa.as_mut_ptr(),
4933 scratch.xg.as_mut_ptr(),
4934 c,
4935 );
4936 if CAPTURE {
4937 let tr = &mut *trace;
4938 tr.xr.copy_from(&scratch.xr);
4939 tr.xw.copy_from(&scratch.xw);
4940 tr.xk.copy_from(&scratch.xk);
4941 tr.xv.copy_from(&scratch.xv);
4942 tr.xa.copy_from(&scratch.xa);
4943 tr.xg.copy_from(&scratch.xg);
4944 }
4945
4946 kernel::copy(
4948 scratch.x_normed.as_ptr(),
4949 layer_state.att_x_prev.as_mut_ptr(),
4950 c,
4951 );
4952
4953 let proj_size = c * c;
4956 kernel::gemv_avx(
4957 attn.rkv_proj.as_ptr(),
4958 scratch.xr.as_ptr(),
4959 scratch.r.as_mut_ptr(),
4960 c,
4961 c,
4962 );
4963 kernel::gemv_avx(
4964 attn.rkv_proj.as_ptr().add(proj_size),
4965 scratch.xk.as_ptr(),
4966 scratch.k.as_mut_ptr(),
4967 c,
4968 c,
4969 );
4970 kernel::gemv_avx(
4971 attn.rkv_proj.as_ptr().add(2 * proj_size),
4972 scratch.xv.as_ptr(),
4973 scratch.v.as_mut_ptr(),
4974 c,
4975 c,
4976 );
4977 if CAPTURE {
4978 let tr = &mut *trace;
4979 tr.r.copy_from(&scratch.r);
4980 tr.k_pre.copy_from(&scratch.k);
4981 tr.v_pre.copy_from(&scratch.v);
4982 }
4983
4984 kernel::gemv_avx(
4987 attn.w1.as_ptr(),
4988 scratch.xw.as_ptr(),
4989 scratch.w_lora_tmp.as_mut_ptr(),
4990 d_w,
4991 c,
4992 );
4993 kernel::tanh_avx(
4995 scratch.w_lora_tmp.as_ptr(),
4996 scratch.w_lora_tmp.as_mut_ptr(),
4997 d_w,
4998 );
4999 if CAPTURE {
5000 let tr = &mut *trace;
5001 tr.w_hidden.as_mut_slice()[0..d_w]
5002 .copy_from_slice(&scratch.w_lora_tmp.as_slice()[0..d_w]);
5003 }
5004 kernel::gemv_avx(
5006 attn.w2.as_ptr(),
5007 scratch.w_lora_tmp.as_ptr(),
5008 scratch.w_decay.as_mut_ptr(),
5009 c,
5010 d_w,
5011 );
5012 kernel::add_avx(
5014 scratch.w_decay.as_ptr(),
5015 attn.w0.as_ptr(),
5016 scratch.w_decay.as_mut_ptr(),
5017 c,
5018 );
5019 if CAPTURE {
5020 let tr = &mut *trace;
5021 tr.w_pre.copy_from(&scratch.w_decay);
5022 }
5023 let inv_sqrt_e = 1.0 / std::f32::consts::E.sqrt();
5025 kernel::sigmoid_exp_neg_scaled_avx(
5026 scratch.w_decay.as_ptr(),
5027 scratch.w_decay.as_mut_ptr(),
5028 if CAPTURE {
5029 (*trace).w_sigmoid.as_mut_ptr()
5030 } else {
5031 std::ptr::null_mut()
5032 },
5033 inv_sqrt_e,
5034 c,
5035 );
5036 if CAPTURE {
5037 let tr = &mut *trace;
5038 tr.w_decay.copy_from(&scratch.w_decay);
5039 }
5040
5041 kernel::gemv_avx(
5043 attn.a1.as_ptr(),
5044 scratch.xa.as_ptr(),
5045 scratch.w_lora_tmp.as_mut_ptr(),
5046 d_a,
5047 c,
5048 );
5049 if CAPTURE {
5050 let tr = &mut *trace;
5051 tr.a_hidden.as_mut_slice()[0..d_a]
5052 .copy_from_slice(&scratch.w_lora_tmp.as_slice()[0..d_a]);
5053 }
5054 kernel::gemv_avx(
5055 attn.a2.as_ptr(),
5056 scratch.w_lora_tmp.as_ptr(),
5057 scratch.a.as_mut_ptr(),
5058 c,
5059 d_a,
5060 );
5061 kernel::add_avx(
5062 scratch.a.as_ptr(),
5063 attn.a0.as_ptr(),
5064 scratch.a.as_mut_ptr(),
5065 c,
5066 );
5067 kernel::sigmoid_avx(scratch.a.as_ptr(), scratch.a.as_mut_ptr(), c);
5068 if CAPTURE {
5069 let tr = &mut *trace;
5070 tr.a.copy_from(&scratch.a);
5071 }
5072
5073 kernel::gemv_avx(
5075 attn.g1.as_ptr(),
5076 scratch.xg.as_ptr(),
5077 scratch.w_lora_tmp.as_mut_ptr(),
5078 d_g,
5079 c,
5080 );
5081 kernel::sigmoid_avx(
5082 scratch.w_lora_tmp.as_ptr(),
5083 scratch.w_lora_tmp.as_mut_ptr(),
5084 d_g,
5085 );
5086 if CAPTURE {
5087 let tr = &mut *trace;
5088 tr.g_hidden.as_mut_slice()[0..d_g]
5089 .copy_from_slice(&scratch.w_lora_tmp.as_slice()[0..d_g]);
5090 }
5091 kernel::gemv_avx(
5092 attn.g2.as_ptr(),
5093 scratch.w_lora_tmp.as_ptr(),
5094 scratch.g.as_mut_ptr(),
5095 c,
5096 d_g,
5097 );
5098 if CAPTURE {
5099 let tr = &mut *trace;
5100 tr.g.copy_from(&scratch.g);
5101 }
5102
5103 if layer_idx == 0 {
5105 state.v_first.copy_from(&scratch.v);
5107 state.v_first_set = true;
5108 if CAPTURE {
5109 let tr = &mut *trace;
5110 tr.uses_v_residual = false;
5111 tr.v.copy_from(&scratch.v);
5112 }
5113 } else if state.v_first_set
5114 && let (Some(v1), Some(v2), Some(v0)) = (&attn.v1, &attn.v2, &attn.v0)
5115 {
5116 let d_v = self.cfg.v_low_rank;
5117 kernel::gemv_avx(
5119 v1.as_ptr(),
5120 scratch.xv.as_ptr(),
5121 scratch.w_lora_tmp.as_mut_ptr(),
5122 d_v,
5123 c,
5124 );
5125 if CAPTURE {
5126 let tr = &mut *trace;
5127 tr.v_hidden.as_mut_slice()[0..d_v]
5128 .copy_from_slice(&scratch.w_lora_tmp.as_slice()[0..d_v]);
5129 }
5130 kernel::gemv_avx(
5131 v2.as_ptr(),
5132 scratch.w_lora_tmp.as_ptr(),
5133 scratch.att_out.as_mut_ptr(), c,
5135 d_v,
5136 );
5137 kernel::add_avx(
5138 scratch.att_out.as_ptr(),
5139 v0.as_ptr(),
5140 scratch.att_out.as_mut_ptr(),
5141 c,
5142 );
5143 kernel::sigmoid_avx(scratch.att_out.as_ptr(), scratch.att_out.as_mut_ptr(), c);
5144 if CAPTURE {
5145 let tr = &mut *trace;
5146 tr.uses_v_residual = true;
5147 tr.nu.copy_from(&scratch.att_out);
5148 }
5149 for i in 0..c {
5151 let nu = scratch.att_out[i];
5152 scratch.v[i] += (state.v_first[i] - scratch.v[i]) * nu;
5153 }
5154 if CAPTURE {
5155 let tr = &mut *trace;
5156 tr.v.copy_from(&scratch.v);
5157 }
5158 } else if CAPTURE {
5159 let tr = &mut *trace;
5160 tr.uses_v_residual = false;
5161 tr.v.copy_from(&scratch.v);
5162 }
5163
5164 kernel::mul_avx(
5166 scratch.k.as_ptr(),
5167 attn.k_k.as_ptr(),
5168 scratch.kk.as_mut_ptr(),
5169 c,
5170 );
5171 if CAPTURE {
5172 let tr = &mut *trace;
5173 tr.kk_pre.copy_from(&scratch.kk);
5174 }
5175 for head in 0..h {
5177 let offset = head * n;
5178 kernel::l2_normalize_avx(
5179 scratch.kk.as_ptr().add(offset),
5180 scratch.kk.as_mut_ptr().add(offset),
5181 n,
5182 1e-12,
5183 );
5184 }
5185 if CAPTURE {
5186 let tr = &mut *trace;
5187 tr.kk.copy_from(&scratch.kk);
5188 }
5189
5190 for i in 0..c {
5192 let scale = 1.0 + (scratch.a[i] - 1.0) * attn.k_a[i];
5193 scratch.k[i] *= scale;
5194 }
5195 if CAPTURE {
5196 let tr = &mut *trace;
5197 tr.k.copy_from(&scratch.k);
5198 }
5199
5200 kernel::rwkv7_wkv_update_avx(
5202 layer_state.att_state.as_mut_ptr(),
5203 scratch.w_decay.as_ptr(),
5204 scratch.k.as_ptr(),
5205 scratch.v.as_ptr(),
5206 scratch.kk.as_ptr(),
5207 scratch.a.as_ptr(),
5208 scratch.r.as_ptr(),
5209 scratch.y.as_mut_ptr(),
5210 h,
5211 n,
5212 );
5213 if CAPTURE {
5214 let tr = &mut *trace;
5215 tr.y_wkv.copy_from(&scratch.y);
5216 }
5217
5218 kernel::group_norm_avx(
5220 scratch.y.as_ptr(),
5221 attn.g_norm_w.as_ptr(),
5222 attn.g_norm_b.as_ptr(),
5223 scratch.y.as_mut_ptr(),
5224 h,
5225 n,
5226 self.cfg.group_norm_eps,
5227 );
5228 if CAPTURE {
5229 let tr = &mut *trace;
5230 tr.y_gn.copy_from(&scratch.y);
5231 }
5232
5233 for head in 0..h {
5235 let offset = head * n;
5236 let mut alpha = 0.0f32;
5237 for j in 0..n {
5238 alpha += scratch.r[offset + j] * scratch.k[offset + j] * attn.r_k[head * n + j];
5239 }
5240 if CAPTURE {
5241 let tr = &mut *trace;
5242 tr.alpha[head] = alpha;
5243 }
5244 for j in 0..n {
5245 scratch.y[offset + j] += alpha * scratch.v[offset + j];
5246 }
5247 }
5248 if CAPTURE {
5249 let tr = &mut *trace;
5250 tr.y_head.copy_from(&scratch.y);
5251 }
5252
5253 kernel::mul_avx(
5255 scratch.y.as_ptr(),
5256 scratch.g.as_ptr(),
5257 scratch.y.as_mut_ptr(),
5258 c,
5259 );
5260 if CAPTURE {
5261 let tr = &mut *trace;
5262 tr.y_gate.copy_from(&scratch.y);
5263 }
5264
5265 kernel::gemv_avx(
5267 attn.o_proj.as_ptr(),
5268 scratch.y.as_ptr(),
5269 scratch.att_out.as_mut_ptr(),
5270 c,
5271 c,
5272 );
5273 if CAPTURE {
5274 let tr = &mut *trace;
5275 tr.att_out.copy_from(&scratch.att_out);
5276 }
5277 }
5278
    /// Channel-mixing (FFN) forward pass for one token at layer `layer_idx`.
    ///
    /// Reads the layer-normed input from `scratch.x_normed`, applies a token
    /// shift against `layer_state.ffn_x_prev`, then key-projection -> ReLU^2
    /// -> value-projection, leaving the result in `scratch.ffn_out`.
    /// Updates `layer_state.ffn_x_prev` with the current input.
    ///
    /// When `CAPTURE` is true, intermediate activations are copied into
    /// `trace` so the training backward pass can replay this step.
    ///
    /// # Safety
    /// If `CAPTURE` is true, `trace` must be a valid pointer to a
    /// `LayerTrainTrace` with no other live references; it is dereferenced
    /// without checks. All scratch/weight buffers must have the sizes implied
    /// by `self.cfg` (the raw-pointer kernel calls do no bounds checking).
    #[inline(always)]
    unsafe fn ffn_forward_impl<const CAPTURE: bool>(
        &self,
        scratch: &mut ScratchBuffers,
        layer_idx: usize,
        layer_state: &mut LayerState,
        trace: *mut LayerTrainTrace,
    ) {
        let ffn = &self.blocks[layer_idx].ffn;
        let c = self.cfg.hidden_size;
        let i = self.cfg.intermediate_size;
        if CAPTURE {
            let tr = &mut *trace;
            // Snapshot the pre-update shift state for the backward pass.
            tr.ffn_x_prev_old.copy_from(&layer_state.ffn_x_prev);
        }

        // Token shift: blend the current normed input with the previous
        // token's input using the learned mix coefficients `x_k`.
        kernel::token_shift_avx(
            scratch.x_normed.as_ptr(),
            layer_state.ffn_x_prev.as_ptr(),
            ffn.x_k.as_ptr(),
            scratch.xk.as_mut_ptr(),
            c,
        );
        if CAPTURE {
            let tr = &mut *trace;
            tr.ffn_xk.copy_from(&scratch.xk);
        }

        // Current input becomes the "previous token" for the next step.
        kernel::copy(
            scratch.x_normed.as_ptr(),
            layer_state.ffn_x_prev.as_mut_ptr(),
            c,
        );

        // Up-projection to the intermediate width (c -> i).
        kernel::gemv_avx(
            ffn.key_w.as_ptr(),
            scratch.xk.as_ptr(),
            scratch.ffn_k.as_mut_ptr(),
            i,
            c,
        );
        if CAPTURE {
            let tr = &mut *trace;
            // Pre-activation values, needed for the ReLU^2 backward.
            tr.ffn_pre.copy_from(&scratch.ffn_k);
        }
        // ReLU-squared activation, applied in place.
        kernel::relu_squared_avx(scratch.ffn_k.as_ptr(), scratch.ffn_k.as_mut_ptr(), i);
        if CAPTURE {
            let tr = &mut *trace;
            tr.ffn_k.copy_from(&scratch.ffn_k);
        }

        // Down-projection back to the hidden width (i -> c).
        kernel::gemv_avx(
            ffn.value_w.as_ptr(),
            scratch.ffn_k.as_ptr(),
            scratch.ffn_out.as_mut_ptr(),
            c,
            i,
        );
        if CAPTURE {
            let tr = &mut *trace;
            tr.ffn_out.copy_from(&scratch.ffn_out);
        }
    }
5346}
5347
/// Backward pass for layer normalization with a learned scale.
///
/// Given the forward inputs (`input`, `weight`, `eps`) and the upstream
/// gradient `grad_out`, writes the gradients w.r.t. the input, the scale
/// and the bias. All slices are clamped to their common length so a
/// mismatched caller cannot index out of bounds; nothing happens when that
/// common length is zero. Uses the standard two-pass formulation:
/// `dx = (gw * N - sum(gw) - xhat * sum(gw * xhat)) * inv_std / N`.
fn layer_norm_backward(
    input: &[f32],
    weight: &[f32],
    grad_out: &[f32],
    eps: f32,
    grad_input: &mut [f32],
    grad_weight: &mut [f32],
    grad_bias: &mut [f32],
) {
    let n = input
        .len()
        .min(weight.len())
        .min(grad_out.len())
        .min(grad_input.len())
        .min(grad_weight.len())
        .min(grad_bias.len());
    if n == 0 {
        return;
    }
    let nf = n as f32;
    // Mean and (biased) variance of the forward input.
    let mean = input[..n].iter().sum::<f32>() / nf;
    let var = input[..n]
        .iter()
        .map(|&x| {
            let d = x - mean;
            d * d
        })
        .sum::<f32>()
        / nf;
    let inv_std = (var + eps).sqrt().recip();
    // First sweep: parameter grads plus the two reduction terms.
    let mut sum_gw = 0.0f32;
    let mut sum_gw_xhat = 0.0f32;
    for (((&x, &w), &go), (gw_slot, gb_slot)) in input[..n]
        .iter()
        .zip(&weight[..n])
        .zip(&grad_out[..n])
        .zip(grad_weight[..n].iter_mut().zip(grad_bias[..n].iter_mut()))
    {
        let xhat = (x - mean) * inv_std;
        let gw = go * w;
        *gw_slot = go * xhat;
        *gb_slot = go;
        sum_gw += gw;
        sum_gw_xhat += gw * xhat;
    }
    // Second sweep: gradient w.r.t. the input.
    for (((&x, &w), &go), gi_slot) in input[..n]
        .iter()
        .zip(&weight[..n])
        .zip(&grad_out[..n])
        .zip(grad_input[..n].iter_mut())
    {
        let xhat = (x - mean) * inv_std;
        let gw = go * w;
        *gi_slot = (gw * nf - sum_gw - xhat * sum_gw_xhat) * inv_std / nf;
    }
}
5397
/// Backward pass for group normalization with learned scale and bias.
///
/// Per-group mean/variance are recomputed from `input`, then the standard
/// normalization backward is applied within each group. Gradients are
/// zero-initialized over the common prefix first, so channels outside any
/// processed group come back as exact zeros. Only groups that fit entirely
/// within the common slice length are processed (`num_groups.min(c / group_size)`),
/// which guarantees every processed group has exactly `group_size` elements.
#[allow(clippy::needless_range_loop, clippy::too_many_arguments)]
fn group_norm_backward(
    input: &[f32],
    weight: &[f32],
    grad_out: &[f32],
    num_groups: usize,
    group_size: usize,
    eps: f32,
    grad_input: &mut [f32],
    grad_weight: &mut [f32],
    grad_bias: &mut [f32],
) {
    let c = input
        .len()
        .min(weight.len())
        .min(grad_out.len())
        .min(grad_input.len())
        .min(grad_weight.len())
        .min(grad_bias.len());
    if c == 0 || num_groups == 0 || group_size == 0 {
        return;
    }
    grad_input[..c].fill(0.0);
    grad_weight[..c].fill(0.0);
    grad_bias[..c].fill(0.0);
    let full_groups = num_groups.min(c / group_size);
    let n = group_size as f32;
    for group in 0..full_groups {
        let off = group * group_size;
        let end = (off + group_size).min(c);
        let len = end - off;
        if len == 0 {
            continue;
        }
        // Per-group statistics (len == group_size for every processed group,
        // so dividing by `len` here and by `n` below is consistent).
        let mean = input[off..end].iter().sum::<f32>() / len as f32;
        let var = input[off..end]
            .iter()
            .map(|&x| {
                let d = x - mean;
                d * d
            })
            .sum::<f32>()
            / len as f32;
        let inv_std = (var + eps).sqrt().recip();
        // First sweep: accumulate parameter grads and the reduction terms.
        let mut sum_gw = 0.0f32;
        let mut sum_gw_xhat = 0.0f32;
        for idx in off..end {
            let xhat = (input[idx] - mean) * inv_std;
            let gw = grad_out[idx] * weight[idx];
            grad_weight[idx] += grad_out[idx] * xhat;
            grad_bias[idx] += grad_out[idx];
            sum_gw += gw;
            sum_gw_xhat += gw * xhat;
        }
        // Second sweep: input gradient for this group.
        for idx in off..end {
            let xhat = (input[idx] - mean) * inv_std;
            let gw = grad_out[idx] * weight[idx];
            grad_input[idx] = (gw * n - sum_gw - xhat * sum_gw_xhat) * inv_std / n;
        }
    }
}
5461
/// Backward of L2 normalization `y = x / max(||x||, min_norm)`.
///
/// When the true norm exceeds `min_norm`, the Jacobian of `x / ||x||` is
/// applied: the component of `grad_out` along `y` is projected out and the
/// remainder scaled by `1/||x||`. When the norm is at or below `min_norm`,
/// the forward division was by the constant `min_norm`, so the gradient is
/// simply `grad_out / min_norm`. All slices are clamped to a common length.
fn l2_normalize_backward(
    x: &[f32],
    y: &[f32],
    grad_out: &[f32],
    min_norm: f32,
    grad_input: &mut [f32],
) {
    let n = x
        .len()
        .min(y.len())
        .min(grad_out.len())
        .min(grad_input.len());
    if n == 0 {
        return;
    }
    let norm = x[..n].iter().map(|&v| v * v).sum::<f32>().sqrt();
    if norm <= min_norm {
        // Clamped branch: forward used the constant divisor min_norm.
        let inv = min_norm.recip();
        for (gi, &go) in grad_input[..n].iter_mut().zip(&grad_out[..n]) {
            *gi = go * inv;
        }
        return;
    }
    // Remove the radial component, then scale by 1/||x||.
    let dot: f32 = grad_out[..n]
        .iter()
        .zip(&y[..n])
        .map(|(&go, &yy)| go * yy)
        .sum();
    let inv = norm.recip();
    for ((gi, &go), &yy) in grad_input[..n]
        .iter_mut()
        .zip(&grad_out[..n])
        .zip(&y[..n])
    {
        *gi = (go - yy * dot) * inv;
    }
}
5499
/// Element-wise gradient accumulation: `dst[i] += src[i]`.
/// `zip` stops at the shorter slice, so mismatched lengths are safe.
#[inline(always)]
fn add_vec_grad(dst: &mut [f32], src: &[f32]) {
    for (acc, &g) in dst.iter_mut().zip(src.iter()) {
        *acc += g;
    }
}
5507
/// Accumulates the rank-1 outer product `left ⊗ right` into the row-major
/// `rows x cols` matrix `dst`: `dst[r*cols + c] += left[r] * right[c]`.
///
/// Dimensions are clamped to the available slice lengths, a truncated final
/// row is handled, and rows with a zero coefficient are skipped entirely.
#[inline(always)]
fn add_outer_grad(dst: &mut [f32], rows: usize, cols: usize, left: &[f32], right: &[f32]) {
    let rows = rows.min(left.len());
    let cols = cols.min(right.len());
    let total = dst.len();
    if rows == 0 || cols == 0 || total == 0 {
        return;
    }
    for (r, &scale) in left.iter().take(rows).enumerate() {
        if scale == 0.0 {
            // Zero row coefficient contributes nothing.
            continue;
        }
        let start = r * cols;
        if start >= total {
            break;
        }
        let row = &mut dst[start..(start + cols).min(total)];
        for (acc, &rv) in row.iter_mut().zip(right.iter()) {
            *acc += scale * rv;
        }
    }
}
5532
/// Plain SGD step over a flat parameter slice: `param[i] += lr * grad[i]`,
/// with optional element-wise clipping to `±clip` when `clip > 0`.
/// The additive sign matches the other update helpers in this file — the
/// caller encodes the descent direction in `lr`/`grad`.
#[inline(always)]
fn sgd_vec_update(param: &mut [f32], grad: &[f32], lr: f32, clip: f32) {
    let n = param.len().min(grad.len());
    if clip > 0.0 {
        for (p, &g) in param[..n].iter_mut().zip(&grad[..n]) {
            *p += lr * g.clamp(-clip, clip);
        }
    } else {
        for (p, &g) in param[..n].iter_mut().zip(&grad[..n]) {
            *p += lr * g;
        }
    }
}
5549
/// SGD step where the gradient is the implicit outer product `left ⊗ right`
/// over a row-major `rows x cols` matrix:
/// `param[r*cols + c] += lr * left[r] * right[c]` (clamped to `±clip` per
/// element when `clip > 0`). Dimensions are clamped to the slice lengths and
/// a truncated final row is handled.
#[inline(always)]
fn sgd_outer_update(
    param: &mut [f32],
    rows: usize,
    cols: usize,
    left: &[f32],
    right: &[f32],
    lr: f32,
    clip: f32,
) {
    let rows = rows.min(left.len());
    let cols = cols.min(right.len());
    let total = param.len();
    if rows == 0 || cols == 0 || total == 0 {
        return;
    }
    for (r, &g) in left.iter().take(rows).enumerate() {
        let start = r * cols;
        if start >= total {
            break;
        }
        let row = &mut param[start..(start + cols).min(total)];
        if clip > 0.0 {
            for (p, &rv) in row.iter_mut().zip(right.iter()) {
                *p += lr * (g * rv).clamp(-clip, clip);
            }
        } else {
            for (p, &rv) in row.iter_mut().zip(right.iter()) {
                *p += lr * g * rv;
            }
        }
    }
}
5585
/// Fused backward + SGD update for a row-major `rows x cols` weight matrix.
///
/// For each row `r` with output-side gradient `left[r]`, this simultaneously
/// (a) accumulates the input-side gradient using the weight value *before*
/// the update — `grad_input[c] += W_old[r][c] * left[r]` — and
/// (b) applies the SGD step `W[r][c] += lr * left[r] * right[c]`, with the
/// per-element gradient clamped to `±clip` when `clip > 0`.
///
/// Rows with a zero gradient are skipped, all dimensions are clamped to the
/// slice lengths, and the unclipped path uses an 8-lane SIMD fast loop with
/// a scalar tail. The clipped path is fully scalar.
#[inline(always)]
#[allow(clippy::needless_range_loop, clippy::too_many_arguments)]
fn fused_sgd_head_backward_update(
    param: &mut [f32],
    rows: usize,
    cols: usize,
    left: &[f32],
    right: &[f32],
    grad_input: &mut [f32],
    lr: f32,
    clip: f32,
) {
    let rows = rows.min(left.len());
    let cols = cols.min(right.len()).min(grad_input.len());
    let n = param.len();
    if rows == 0 || cols == 0 || n == 0 {
        return;
    }
    let do_clip = clip > 0.0;
    let lr8 = f32x8::splat(lr);
    for row in 0..rows {
        let g = left[row];
        if g == 0.0 {
            // Zero output gradient: neither backward nor update contributes.
            continue;
        }
        let off = row * cols;
        if off >= n {
            break;
        }
        let row_cols = cols.min(n - off);
        if do_clip {
            // Scalar clipped path; reads w_old before writing the update.
            for col in 0..row_cols {
                let idx = off + col;
                let w_old = param[idx];
                grad_input[col] += w_old * g;
                param[idx] = w_old + lr * (g * right[col]).clamp(-clip, clip);
            }
            continue;
        }
        let mut col = 0usize;
        // SAFETY: the loop condition guarantees col + 8 <= row_cols, and
        // row_cols is clamped so every 8-wide unaligned read/write stays
        // inside `param`, `right`, and `grad_input`.
        unsafe {
            let g8 = f32x8::splat(g);
            while col + 8 <= row_cols {
                let idx = off + col;
                let wv = param.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
                let rv = right.as_ptr().add(col).cast::<f32x8>().read_unaligned();
                let giv = grad_input
                    .as_ptr()
                    .add(col)
                    .cast::<f32x8>()
                    .read_unaligned();
                // Backward uses the pre-update weights (wv).
                grad_input
                    .as_mut_ptr()
                    .add(col)
                    .cast::<f32x8>()
                    .write_unaligned(giv + wv * g8);
                param
                    .as_mut_ptr()
                    .add(idx)
                    .cast::<f32x8>()
                    .write_unaligned(wv + (g8 * rv) * lr8);
                col += 8;
            }
        }
        // Scalar tail for the remaining row_cols % 8 elements.
        while col < row_cols {
            let idx = off + col;
            let w_old = param[idx];
            grad_input[col] += w_old * g;
            param[idx] = w_old + lr * g * right[col];
            col += 1;
        }
    }
}
5659
5660#[inline(always)]
5661fn apply_adam_vec_update(
5662 param: &mut [f32],
5663 grad: &[f32],
5664 adam: &mut AdamTensorState,
5665 step: &AdamStep,
5666) {
5667 let n = param
5668 .len()
5669 .min(grad.len())
5670 .min(adam.m.len())
5671 .min(adam.v.len());
5672 if n == 0 {
5673 return;
5674 }
5675 apply_adam_vec_update_raw(
5676 &mut param[0..n],
5677 &grad[0..n],
5678 &mut adam.m.as_mut_slice()[0..n],
5679 &mut adam.v.as_mut_slice()[0..n],
5680 step,
5681 );
5682}
5683
/// Core Adam update over flat slices.
///
/// Updates the first/second moment buffers `m`/`v` in place and applies
/// `param[i] += lr * m_hat / (sqrt(v_hat) + eps)`. The additive sign matches
/// the other update helpers in this file (the caller encodes the descent
/// direction in its gradient/learning-rate convention). `step.bias_corr1/2`
/// hold the bias-correction denominators, precomputed by the caller.
///
/// When `step.clip > 0`, gradients are clamped element-wise and the loop is
/// fully scalar; otherwise an 8-lane SIMD fast path handles the bulk with a
/// scalar tail for the remainder.
#[inline(always)]
fn apply_adam_vec_update_raw(
    param: &mut [f32],
    grad: &[f32],
    m: &mut [f32],
    v: &mut [f32],
    step: &AdamStep,
) {
    let n = param.len().min(grad.len()).min(m.len()).min(v.len());
    if n == 0 {
        return;
    }
    let b1 = step.b1;
    let b2 = step.b2;
    let one_b1 = 1.0 - b1;
    let one_b2 = 1.0 - b2;
    let inv_bc1 = 1.0 / step.bias_corr1;
    let inv_bc2 = 1.0 / step.bias_corr2;
    let do_clip = step.clip > 0.0;
    let clip = step.clip;
    if do_clip {
        // Scalar clipped path: clamp the raw gradient before the moments.
        for idx in 0..n {
            let g = grad[idx].clamp(-clip, clip);
            let mm = b1 * m[idx] + one_b1 * g;
            let vv = b2 * v[idx] + one_b2 * g * g;
            m[idx] = mm;
            v[idx] = vv;
            let m_hat = mm * inv_bc1;
            let v_hat = vv * inv_bc2;
            param[idx] += step.lr * m_hat / (v_hat.sqrt() + step.eps);
        }
        return;
    }
    let mut idx = 0usize;
    // SAFETY: the loop condition guarantees idx + 8 <= n and all four slices
    // are at least n long, so every unaligned 8-wide read/write is in bounds.
    unsafe {
        let b1v = f32x8::splat(b1);
        let b2v = f32x8::splat(b2);
        let one_b1v = f32x8::splat(one_b1);
        let one_b2v = f32x8::splat(one_b2);
        let inv_bc1v = f32x8::splat(inv_bc1);
        let inv_bc2v = f32x8::splat(inv_bc2);
        let lrv = f32x8::splat(step.lr);
        let epsv = f32x8::splat(step.eps);
        while idx + 8 <= n {
            let gv = grad.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
            let mv = m.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
            let vv = v.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
            let mm = mv * b1v + gv * one_b1v;
            let vv2 = vv * b2v + (gv * gv) * one_b2v;
            m.as_mut_ptr().add(idx).cast::<f32x8>().write_unaligned(mm);
            v.as_mut_ptr().add(idx).cast::<f32x8>().write_unaligned(vv2);
            let pv = param.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
            let upd = ((mm * inv_bc1v) / ((vv2 * inv_bc2v).sqrt() + epsv)) * lrv;
            param
                .as_mut_ptr()
                .add(idx)
                .cast::<f32x8>()
                .write_unaligned(pv + upd);
            idx += 8;
        }
    }
    // Scalar tail for the remaining n % 8 elements.
    while idx < n {
        let g = grad[idx];
        let mm = b1 * m[idx] + one_b1 * g;
        let vv = b2 * v[idx] + one_b2 * g * g;
        m[idx] = mm;
        v[idx] = vv;
        let m_hat = mm * inv_bc1;
        let v_hat = vv * inv_bc2;
        param[idx] += step.lr * m_hat / (v_hat.sqrt() + step.eps);
        idx += 1;
    }
}
5757
/// Fused backward + Adam update for a row-major `rows x cols` weight matrix.
///
/// For each row `r` with output-side gradient `left[r]`, this simultaneously
/// (a) accumulates the input-side gradient from the *pre-update* weights —
/// `grad_input[c] += W_old[r][c] * left[r]` — and (b) runs an Adam step on
/// the row with the implicit gradient `left[r] * right[c]`, updating the
/// moment buffers `m`/`v` in place. Bias corrections come precomputed in
/// `step.bias_corr1/2`; the additive parameter sign matches the rest of this
/// file's update helpers.
///
/// Rows with a zero gradient are skipped; all dimensions are clamped to the
/// slice lengths. The clipped path (`step.clip > 0`) is scalar; the unclipped
/// path uses 8-lane SIMD with a scalar tail.
#[inline(always)]
#[allow(clippy::needless_range_loop, clippy::too_many_arguments)]
fn fused_adam_head_backward_update(
    param: &mut [f32],
    rows: usize,
    cols: usize,
    left: &[f32],
    right: &[f32],
    grad_input: &mut [f32],
    m: &mut [f32],
    v: &mut [f32],
    step: &AdamStep,
) {
    let rows = rows.min(left.len());
    let cols = cols.min(right.len()).min(grad_input.len());
    let n = param.len().min(m.len()).min(v.len());
    if rows == 0 || cols == 0 || n == 0 {
        return;
    }
    let b1 = step.b1;
    let b2 = step.b2;
    let one_b1 = 1.0 - b1;
    let one_b2 = 1.0 - b2;
    let inv_bc1 = 1.0 / step.bias_corr1;
    let inv_bc2 = 1.0 / step.bias_corr2;
    let do_clip = step.clip > 0.0;
    let clip = step.clip;
    // Broadcast all scalar Adam constants once, outside the row loop.
    let b1v = f32x8::splat(b1);
    let b2v = f32x8::splat(b2);
    let one_b1v = f32x8::splat(one_b1);
    let one_b2v = f32x8::splat(one_b2);
    let inv_bc1v = f32x8::splat(inv_bc1);
    let inv_bc2v = f32x8::splat(inv_bc2);
    let epsv = f32x8::splat(step.eps);
    let lrv = f32x8::splat(step.lr);
    for row in 0..rows {
        let g = left[row];
        if g == 0.0 {
            // Zero output gradient: skips both backward and Adam moments.
            // NOTE(review): this also skips the moment decay for the row —
            // presumably intentional for sparse gradients; confirm.
            continue;
        }
        let off = row * cols;
        if off >= n {
            break;
        }
        let row_cols = cols.min(n - off);
        if do_clip {
            // Scalar clipped path; reads w_old before the parameter write.
            for col in 0..row_cols {
                let idx = off + col;
                let w_old = param[idx];
                grad_input[col] += w_old * g;
                let gg = (g * right[col]).clamp(-clip, clip);
                let mm = b1 * m[idx] + one_b1 * gg;
                let vv = b2 * v[idx] + one_b2 * gg * gg;
                m[idx] = mm;
                v[idx] = vv;
                let m_hat = mm * inv_bc1;
                let v_hat = vv * inv_bc2;
                param[idx] = w_old + step.lr * m_hat / (v_hat.sqrt() + step.eps);
            }
            continue;
        }
        let mut col = 0usize;
        // SAFETY: the loop condition guarantees col + 8 <= row_cols, and
        // row_cols is clamped to keep every unaligned 8-wide access inside
        // `param`, `right`, `grad_input`, `m`, and `v`.
        unsafe {
            let g8 = f32x8::splat(g);
            while col + 8 <= row_cols {
                let idx = off + col;
                let wv = param.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
                let rv = right.as_ptr().add(col).cast::<f32x8>().read_unaligned();
                let giv = grad_input
                    .as_ptr()
                    .add(col)
                    .cast::<f32x8>()
                    .read_unaligned();
                // Backward uses the pre-update weights (wv).
                grad_input
                    .as_mut_ptr()
                    .add(col)
                    .cast::<f32x8>()
                    .write_unaligned(giv + wv * g8);
                let gv = g8 * rv;
                let mv = m.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
                let vv = v.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
                let mm = mv * b1v + gv * one_b1v;
                let vv2 = vv * b2v + (gv * gv) * one_b2v;
                m.as_mut_ptr().add(idx).cast::<f32x8>().write_unaligned(mm);
                v.as_mut_ptr().add(idx).cast::<f32x8>().write_unaligned(vv2);
                let upd = ((mm * inv_bc1v) / ((vv2 * inv_bc2v).sqrt() + epsv)) * lrv;
                param
                    .as_mut_ptr()
                    .add(idx)
                    .cast::<f32x8>()
                    .write_unaligned(wv + upd);
                col += 8;
            }
        }
        // Scalar tail for the remaining row_cols % 8 elements.
        while col < row_cols {
            let idx = off + col;
            let w_old = param[idx];
            grad_input[col] += w_old * g;
            let gg = g * right[col];
            let mm = b1 * m[idx] + one_b1 * gg;
            let vv = b2 * v[idx] + one_b2 * gg * gg;
            m[idx] = mm;
            v[idx] = vv;
            let m_hat = mm * inv_bc1;
            let v_hat = vv * inv_bc2;
            param[idx] = w_old + step.lr * m_hat / (v_hat.sqrt() + step.eps);
            col += 1;
        }
    }
}
5868
5869#[inline(always)]
5870fn apply_adam_outer_update(
5871 param: &mut [f32],
5872 rows: usize,
5873 cols: usize,
5874 left: &[f32],
5875 right: &[f32],
5876 adam: &mut AdamTensorState,
5877 step: &AdamStep,
5878) {
5879 let n = param.len().min(adam.m.len()).min(adam.v.len());
5880 if n == 0 {
5881 return;
5882 }
5883 apply_adam_outer_update_raw(
5884 &mut param[0..n],
5885 rows,
5886 cols,
5887 left,
5888 right,
5889 &mut adam.m.as_mut_slice()[0..n],
5890 &mut adam.v.as_mut_slice()[0..n],
5891 step,
5892 );
5893}
5894
/// Adam update for a row-major `rows x cols` matrix where the gradient is the
/// implicit rank-1 outer product `left ⊗ right` (grad[r][c] = left[r] * right[c]).
///
/// Moment buffers `m`/`v` are updated in place; `step.bias_corr1/2` hold the
/// precomputed bias-correction denominators and the parameter update is
/// additive, matching the rest of this file. Unlike the fused variants, rows
/// with `left[row] == 0` are NOT skipped, so the moment decay still applies.
/// The clipped path (`step.clip > 0`) is scalar; otherwise an 8-lane SIMD
/// loop handles the bulk with a scalar tail.
#[allow(clippy::too_many_arguments)]
#[inline(always)]
#[allow(clippy::needless_range_loop)]
fn apply_adam_outer_update_raw(
    param: &mut [f32],
    rows: usize,
    cols: usize,
    left: &[f32],
    right: &[f32],
    m: &mut [f32],
    v: &mut [f32],
    step: &AdamStep,
) {
    let rows = rows.min(left.len());
    let cols = cols.min(right.len());
    let n = param.len().min(m.len()).min(v.len());
    if rows == 0 || cols == 0 || n == 0 {
        return;
    }
    let b1 = step.b1;
    let b2 = step.b2;
    let one_b1 = 1.0 - b1;
    let one_b2 = 1.0 - b2;
    let inv_bc1 = 1.0 / step.bias_corr1;
    let inv_bc2 = 1.0 / step.bias_corr2;
    let do_clip = step.clip > 0.0;
    let clip = step.clip;
    // Broadcast scalar Adam constants once, outside the row loop.
    let b1v = f32x8::splat(b1);
    let b2v = f32x8::splat(b2);
    let one_b1v = f32x8::splat(one_b1);
    let one_b2v = f32x8::splat(one_b2);
    let inv_bc1v = f32x8::splat(inv_bc1);
    let inv_bc2v = f32x8::splat(inv_bc2);
    let epsv = f32x8::splat(step.eps);
    let lrv = f32x8::splat(step.lr);
    for row in 0..rows {
        let g_row = left[row];
        let off = row * cols;
        if off >= n {
            break;
        }
        let row_cols = (n - off).min(cols);
        if do_clip {
            // Scalar clipped path: clamp each outer-product element.
            for col in 0..row_cols {
                let idx = off + col;
                let g = (g_row * right[col]).clamp(-clip, clip);
                let mm = b1 * m[idx] + one_b1 * g;
                let vv = b2 * v[idx] + one_b2 * g * g;
                m[idx] = mm;
                v[idx] = vv;
                let m_hat = mm * inv_bc1;
                let v_hat = vv * inv_bc2;
                param[idx] += step.lr * m_hat / (v_hat.sqrt() + step.eps);
            }
            continue;
        }
        let mut col = 0usize;
        // SAFETY: the loop condition guarantees col + 8 <= row_cols, and
        // row_cols is clamped so every unaligned 8-wide access stays inside
        // `param`, `right`, `m`, and `v`.
        unsafe {
            let g8 = f32x8::splat(g_row);
            while col + 8 <= row_cols {
                let idx = off + col;
                let rv = right.as_ptr().add(col).cast::<f32x8>().read_unaligned();
                let gv = g8 * rv;
                let mv = m.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
                let vv = v.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
                let mm = mv * b1v + gv * one_b1v;
                let vv2 = vv * b2v + (gv * gv) * one_b2v;
                m.as_mut_ptr().add(idx).cast::<f32x8>().write_unaligned(mm);
                v.as_mut_ptr().add(idx).cast::<f32x8>().write_unaligned(vv2);
                let pv = param.as_ptr().add(idx).cast::<f32x8>().read_unaligned();
                let upd = ((mm * inv_bc1v) / ((vv2 * inv_bc2v).sqrt() + epsv)) * lrv;
                param
                    .as_mut_ptr()
                    .add(idx)
                    .cast::<f32x8>()
                    .write_unaligned(pv + upd);
                col += 8;
            }
        }
        // Scalar tail for the remaining row_cols % 8 elements.
        while col < row_cols {
            let idx = off + col;
            let g = g_row * right[col];
            let mm = b1 * m[idx] + one_b1 * g;
            let vv = b2 * v[idx] + one_b2 * g * g;
            m[idx] = mm;
            v[idx] = vv;
            let m_hat = mm * inv_bc1;
            let v_hat = vv * inv_bc2;
            param[idx] += step.lr * m_hat / (v_hat.sqrt() + step.eps);
            col += 1;
        }
    }
}
5988
/// Minimal deterministic LCG used only for reproducible weight
/// initialization — not suitable for anything security-related.
struct RwkvRng {
    // Raw 64-bit generator state.
    state: u64,
}

impl RwkvRng {
    /// Seeds the generator; the XOR constant decorrelates small seeds.
    fn new(seed: u64) -> Self {
        Self {
            state: seed ^ 0x9E37_79B9_7F4A_7C15,
        }
    }

    /// Advances the LCG and returns the high 32 bits of the new state
    /// (the high bits of an LCG are the better-distributed ones).
    #[inline]
    fn next_u32(&mut self) -> u32 {
        const MUL: u64 = 6_364_136_223_846_793_005;
        self.state = self.state.wrapping_mul(MUL).wrapping_add(1);
        (self.state >> 32) as u32
    }

    /// Roughly uniform sample in [0, 1]; note `u32::MAX as f32` rounds up
    /// to 2^32, so the endpoint can be reached only when next_u32 saturates.
    #[inline]
    fn next_f32(&mut self) -> f32 {
        const INV: f32 = 1.0 / (u32::MAX as f32);
        self.next_u32() as f32 * INV
    }
}
6015
6016#[inline]
6017fn init_uniform(t: &mut Tensor1D, rng: &mut RwkvRng, scale: f32) {
6018 let s = t.as_mut_slice();
6019 for v in s {
6020 let r = rng.next_f32() - 0.5;
6021 *v = r * 2.0 * scale;
6022 }
6023}
6024
6025#[inline]
6026fn init_centered(t: &mut Tensor1D, rng: &mut RwkvRng, center: f32, scale: f32) {
6027 let s = t.as_mut_slice();
6028 for v in s {
6029 let r = rng.next_f32() - 0.5;
6030 *v = center + r * 2.0 * scale;
6031 }
6032}
6033
6034#[inline]
6035fn init_const(t: &mut Tensor1D, value: f32) {
6036 t.as_mut_slice().fill(value);
6037}
6038
6039#[cfg(test)]
6040mod tests {
6041 use super::*;
6042
6043 fn test_cfg() -> Config {
6044 Config {
6045 vocab_size: 256,
6046 hidden_size: 64,
6047 num_layers: 1,
6048 num_heads: 1,
6049 head_dim: 64,
6050 intermediate_size: 64,
6051 layer_norm_eps: 1e-5,
6052 group_norm_eps: 64e-5,
6053 decay_low_rank: 8,
6054 a_low_rank: 8,
6055 v_low_rank: 8,
6056 g_low_rank: 8,
6057 }
6058 }
6059
6060 fn softmax_loss(logits: &[f32], target: u8) -> f64 {
6061 let max_logit = logits
6062 .iter()
6063 .copied()
6064 .fold(f32::NEG_INFINITY, |a, b| a.max(b));
6065 let mut sum = 0.0f64;
6066 for &z in logits {
6067 sum += ((z - max_logit) as f64).exp();
6068 }
6069 let p = ((logits[target as usize] - max_logit) as f64).exp() / sum.max(1e-300);
6070 -p.max(1e-300).ln()
6071 }
6072
6073 fn segment_loss(model: &Model, cfg: &Config, steps: &[(u32, u8)]) -> f64 {
6074 if steps.is_empty() {
6075 return 0.0;
6076 }
6077 let mut scratch = ScratchBuffers::new(cfg);
6078 let mut state = model.new_state();
6079 let mut loss = 0.0f64;
6080 for &(input, target) in steps {
6081 let logits = model.forward(&mut scratch, input, &mut state);
6082 loss += softmax_loss(logits, target);
6083 }
6084 loss / (steps.len() as f64)
6085 }
6086
    /// Computes full-model gradients for a token segment via truncated BPTT:
    /// a forward sweep captures a training trace and post-step state per
    /// token, then the tokens are replayed newest-to-oldest to accumulate
    /// gradients into a fresh `FullGradState`.
    fn segment_grads(model: &Model, cfg: &Config, steps: &[(u32, u8)]) -> FullGradState {
        let mut scratch = ScratchBuffers::new(cfg);
        let mut state = model.new_state();
        // states[i] is the recurrent state *before* consuming steps[i];
        // states[i + 1] the state after it.
        let mut states = Vec::with_capacity(steps.len() + 1);
        let mut traces = Vec::with_capacity(steps.len());
        let mut pdfs = Vec::with_capacity(steps.len());
        states.push(state.clone());
        for &(input, _) in steps {
            scratch.set_capture_train_trace(true);
            let logits = model.forward(&mut scratch, input, &mut state);
            // Per-token output distribution, needed by the backward pass.
            let mut pdf = vec![0.0f64; cfg.vocab_size];
            super::super::super::softmax_pdf_floor_with_bias(logits, None, &mut pdf);
            pdfs.push(pdf);
            traces.push(TokenTrainTrace::from_scratch(&scratch));
            states.push(state.clone());
        }
        let mut grads = model.new_full_grad_state();
        let mut recurrent = model.new_recurrent_grad_state();
        // Train every parameter group except biases.
        let scope = TrainScopeMask {
            embed: true,
            pre_norm: true,
            attn_norm: true,
            ffn_norm: true,
            attn: true,
            ffn: true,
            head: true,
            bias: false,
        };
        // Scale so the accumulated gradient is the per-token average.
        let grad_scale = 1.0f32 / (steps.len() as f32);
        // Backward sweep walks tokens in reverse, threading the recurrent
        // gradient state between steps.
        for idx in (0..steps.len()).rev() {
            model
                .accumulate_token_step_gradients(
                    &mut scratch,
                    &traces[idx],
                    &states[idx + 1],
                    steps[idx].1,
                    &pdfs[idx],
                    grad_scale,
                    scope,
                    &mut grads,
                    None,
                    &mut recurrent,
                )
                .expect("segment gradient accumulation");
        }
        grads
    }
6134
    /// Names one hand-picked scalar weight coordinate used by the gradient
    /// checks. `probe_value`, `set_probe`, and `probe_grad` must all address
    /// the same coordinate for a given variant — keep them in sync.
    #[derive(Clone, Copy, Debug)]
    enum Probe {
        Embed,
        LnOutW,
        AttnNormW,
        OProj,
        KProj,
        VProj,
        FfnKey,
    }
6145
    /// Reads the current weight at the probed coordinate.
    /// The `rkv_proj` offsets select the K (after one 64x64 R block) and V
    /// (after two blocks) sub-matrices of the fused projection; indices
    /// assume the `test_cfg()` shape (hidden_size 64).
    fn probe_value(model: &Model, probe: Probe) -> f32 {
        match probe {
            Probe::Embed => model.embeddings[7],
            Probe::LnOutW => model.ln_out_w[5],
            Probe::AttnNormW => model.blocks[0].attn_norm_w[9],
            Probe::OProj => model.blocks[0].attn.o_proj[23],
            Probe::KProj => model.blocks[0].attn.rkv_proj[64 * 64 + 17],
            Probe::VProj => model.blocks[0].attn.rkv_proj[2 * 64 * 64 + 29],
            Probe::FfnKey => model.blocks[0].ffn.key_w[11],
        }
    }
6157
    /// Writes `value` into the probed coordinate; must mirror `probe_value`
    /// index-for-index so finite-difference perturbations hit the same weight.
    fn set_probe(model: &mut Model, probe: Probe, value: f32) {
        match probe {
            Probe::Embed => model.embeddings[7] = value,
            Probe::LnOutW => model.ln_out_w[5] = value,
            Probe::AttnNormW => model.blocks[0].attn_norm_w[9] = value,
            Probe::OProj => model.blocks[0].attn.o_proj[23] = value,
            Probe::KProj => model.blocks[0].attn.rkv_proj[64 * 64 + 17] = value,
            Probe::VProj => model.blocks[0].attn.rkv_proj[2 * 64 * 64 + 29] = value,
            Probe::FfnKey => model.blocks[0].ffn.key_w[11] = value,
        }
    }
6169
    /// Reads the accumulated analytic gradient for the probed coordinate;
    /// indices must stay identical to `probe_value` / `set_probe`.
    fn probe_grad(grads: &FullGradState, probe: Probe) -> f32 {
        match probe {
            Probe::Embed => grads.embeddings[7],
            Probe::LnOutW => grads.ln_out_w[5],
            Probe::AttnNormW => grads.blocks[0].attn_norm_w[9],
            Probe::OProj => grads.blocks[0].attn.o_proj[23],
            Probe::KProj => grads.blocks[0].attn.rkv_proj[64 * 64 + 17],
            Probe::VProj => grads.blocks[0].attn.rkv_proj[2 * 64 * 64 + 29],
            Probe::FfnKey => grads.blocks[0].ffn.key_w[11],
        }
    }
6181
6182 fn weighted_checksum(data: &[f32]) -> f64 {
6183 data.iter()
6184 .enumerate()
6185 .map(|(i, &v)| (i as f64 + 1.0) * (v as f64))
6186 .sum()
6187 }
6188
6189 #[test]
6190 fn test_config_default() {
6191 let cfg = Config::default();
6192 assert_eq!(cfg.vocab_size, 256);
6193 assert_eq!(cfg.hidden_size, 256);
6194 assert_eq!(cfg.num_layers, 12);
6195 assert_eq!(cfg.num_heads, 4);
6196 assert_eq!(cfg.head_dim, 64);
6197 }
6198
    /// Golden-value regression test: a fixed seed and token stream must
    /// reproduce the recorded logit/state checksums within a small tolerance.
    /// If a deliberate numerics change shifts these values, re-record the
    /// expected constants from a verified run.
    #[test]
    fn test_forward_deterministic_snapshot() {
        let cfg = Config {
            vocab_size: 256,
            hidden_size: 64,
            num_layers: 2,
            num_heads: 1,
            head_dim: 64,
            intermediate_size: 128,
            layer_norm_eps: 1e-5,
            group_norm_eps: 64e-5,
            decay_low_rank: 16,
            a_low_rank: 16,
            v_low_rank: 16,
            g_low_rank: 32,
        };
        cfg.validate().expect("valid test config");

        let model = Model::new_random(cfg.clone(), 0x1234_5678_9ABC_DEF0).expect("random model");
        let mut state = model.new_state();
        let mut scratch = ScratchBuffers::new(&cfg);
        let tokens = [0u32, 1, 7, 42, 255, 3, 128, 64, 17, 99];

        // Sample a handful of logit positions per token plus the final
        // logits, so drift anywhere in the output surfaces in a checksum.
        let mut probes = Vec::new();
        let mut last_logits = vec![0.0; 8];

        for &token in &tokens {
            let logits = model.forward(&mut scratch, token, &mut state);
            probes.push(logits[0]);
            probes.push(logits[1]);
            probes.push(logits[2]);
            probes.push(logits[42]);
            probes.push(logits[127]);
            probes.push(logits[255]);
            last_logits.copy_from_slice(&logits[0..8]);
        }

        let probe_checksum = weighted_checksum(&probes);
        let last_logits_checksum = weighted_checksum(&last_logits);
        let state_att_checksum = weighted_checksum(state.layers[0].att_state.as_slice());
        let state_prev_checksum = weighted_checksum(state.layers[1].att_x_prev.as_slice());
        let v_first_checksum = weighted_checksum(state.v_first.as_slice());

        // Recorded from a known-good run; do not edit casually.
        let expected_probe_checksum = 25.674_967_924_598_604_f64;
        let expected_last_logits_checksum = 0.679_873_816_668_987_3_f64;
        let expected_state_att_checksum = 129.962_464_237_222_32_f64;
        let expected_state_prev_checksum = -231.326_208_570_972_08_f64;
        let expected_v_first_checksum = -1.921_361_377_462_744_7_f64;

        // Loose enough to absorb FMA/rounding differences across builds.
        let tol = 2e-4_f64;
        assert!(
            (probe_checksum - expected_probe_checksum).abs() <= tol,
            "probe_checksum={probe_checksum}"
        );
        assert!(
            (last_logits_checksum - expected_last_logits_checksum).abs() <= tol,
            "last_logits_checksum={last_logits_checksum}"
        );
        assert!(
            (state_att_checksum - expected_state_att_checksum).abs() <= tol,
            "state_att_checksum={state_att_checksum}"
        );
        assert!(
            (state_prev_checksum - expected_state_prev_checksum).abs() <= tol,
            "state_prev_checksum={state_prev_checksum}"
        );
        assert!(
            (v_first_checksum - expected_v_first_checksum).abs() <= tol,
            "v_first_checksum={v_first_checksum}"
        );
    }
6270
6271 #[test]
6272 fn traced_and_untraced_forward_match_exactly() {
6273 let cfg = Config {
6274 vocab_size: 256,
6275 hidden_size: 64,
6276 num_layers: 2,
6277 num_heads: 1,
6278 head_dim: 64,
6279 intermediate_size: 128,
6280 layer_norm_eps: 1e-5,
6281 group_norm_eps: 64e-5,
6282 decay_low_rank: 16,
6283 a_low_rank: 16,
6284 v_low_rank: 16,
6285 g_low_rank: 32,
6286 };
6287 cfg.validate().expect("valid test config");
6288 let model = Model::new_random(cfg.clone(), 0xCAFEBABE).expect("random model");
6289 let mut traced_state = model.new_state();
6290 let mut plain_state = model.new_state();
6291 let mut traced_scratch = ScratchBuffers::new(&cfg);
6292 let mut plain_scratch = ScratchBuffers::new(&cfg);
6293 traced_scratch.set_capture_train_trace(true);
6294 plain_scratch.set_capture_train_trace(false);
6295
6296 let tokens = [3u32, 19, 77, 120, 255, 5, 88, 13, 144, 1, 200];
6297 for &token in &tokens {
6298 let traced_logits = model
6299 .forward(&mut traced_scratch, token, &mut traced_state)
6300 .to_vec();
6301 let plain_logits = model
6302 .forward(&mut plain_scratch, token, &mut plain_state)
6303 .to_vec();
6304 for (a, b) in traced_logits.iter().zip(plain_logits.iter()) {
6305 assert_eq!(a.to_bits(), b.to_bits());
6306 }
6307 assert_eq!(traced_state.v_first_set, plain_state.v_first_set);
6308 for (&a, &b) in traced_state
6309 .v_first
6310 .as_slice()
6311 .iter()
6312 .zip(plain_state.v_first.as_slice())
6313 {
6314 assert_eq!(a.to_bits(), b.to_bits());
6315 }
6316 for (tr_layer, plain_layer) in traced_state.layers.iter().zip(plain_state.layers.iter())
6317 {
6318 for (&a, &b) in tr_layer
6319 .att_x_prev
6320 .as_slice()
6321 .iter()
6322 .zip(plain_layer.att_x_prev.as_slice())
6323 {
6324 assert_eq!(a.to_bits(), b.to_bits());
6325 }
6326 for (&a, &b) in tr_layer
6327 .att_state
6328 .as_slice()
6329 .iter()
6330 .zip(plain_layer.att_state.as_slice())
6331 {
6332 assert_eq!(a.to_bits(), b.to_bits());
6333 }
6334 for (&a, &b) in tr_layer
6335 .ffn_x_prev
6336 .as_slice()
6337 .iter()
6338 .zip(plain_layer.ffn_x_prev.as_slice())
6339 {
6340 assert_eq!(a.to_bits(), b.to_bits());
6341 }
6342 }
6343 }
6344 }
6345
6346 #[test]
6347 fn tbptt_segment_gradients_match_finite_difference() {
6348 let cfg = test_cfg();
6349 cfg.validate().expect("valid test config");
6350 let model = Model::new_random(cfg.clone(), 0xD00D_F00D).expect("random model");
6351 let steps = [(0u32, 1u8), (1, 2), (2, 3)];
6352 let grads = segment_grads(&model, &cfg, &steps);
6353 let eps = 1e-3f32;
6354
6355 for probe in [
6356 Probe::Embed,
6357 Probe::LnOutW,
6358 Probe::AttnNormW,
6359 Probe::OProj,
6360 Probe::KProj,
6361 Probe::VProj,
6362 Probe::FfnKey,
6363 ] {
6364 let analytic = probe_grad(&grads, probe);
6365
6366 let mut plus = model.clone();
6367 let base = probe_value(&plus, probe);
6368 set_probe(&mut plus, probe, base + eps);
6369 let loss_plus = segment_loss(&plus, &cfg, &steps);
6370
6371 let mut minus = model.clone();
6372 set_probe(&mut minus, probe, base - eps);
6373 let loss_minus = segment_loss(&minus, &cfg, &steps);
6374
6375 let numeric = -((loss_plus - loss_minus) / (2.0 * eps as f64)) as f32;
6376 let tol = 5e-2f32.max(analytic.abs().max(numeric.abs()) * 8e-2);
6377 assert!(
6378 (analytic - numeric).abs() <= tol,
6379 "probe={probe:?} analytic={analytic} numeric={numeric} tol={tol}"
6380 );
6381 }
6382 }
6383
6384 #[test]
6385 fn tbptt_sgd_step_reduces_mean_segment_loss() {
6386 let cfg = test_cfg();
6387 cfg.validate().expect("valid test config");
6388 let mut model = Model::new_random(cfg.clone(), 0x1234_5678).expect("random model");
6389 let steps = [(0u32, 1u8), (1, 2), (2, 3), (3, 4)];
6390 let before = segment_loss(&model, &cfg, &steps);
6391
6392 let mut scratch = ScratchBuffers::new(&cfg);
6393 let start_state = model.new_state();
6394 let mut live_state = model.new_state();
6395 let mut adam_t = 0usize;
6396 let scope = TrainScopeMask {
6397 embed: true,
6398 pre_norm: true,
6399 attn_norm: true,
6400 ffn_norm: true,
6401 attn: true,
6402 ffn: true,
6403 head: true,
6404 bias: false,
6405 };
6406
6407 model
6408 .online_train_segment_tbptt(
6409 &mut scratch,
6410 &start_state,
6411 &steps,
6412 scope,
6413 OptimizerKind::Sgd,
6414 1e-3,
6415 0.0,
6416 2,
6417 &mut adam_t,
6418 None,
6419 None,
6420 None,
6421 None,
6422 &mut live_state,
6423 )
6424 .expect("tbptt sgd step");
6425
6426 let after = segment_loss(&model, &cfg, &steps);
6427 assert!(
6428 after < before,
6429 "expected SGD TBPTT step to reduce mean loss: before={before} after={after}"
6430 );
6431 }
6432}