use anyhow::{Context, Result};
use std::path::Path;
use std::time::Instant;

use super::kernel;
use super::profiling::{NullProfiler, ProfilerSink};
use super::tensor::Tensor1D;
use super::weights::Weights;

/// Model hyperparameters. The numeric fields are inferred from weight shapes
/// in `Model::load`; `Default` describes a small test configuration.
#[derive(Debug, Clone)]
pub struct Config {
    pub vocab_size: usize,
    pub hidden_size: usize,
    pub num_layers: usize,
    pub num_heads: usize,
    pub head_dim: usize,
    pub intermediate_size: usize,
    pub layer_norm_eps: f32,
    pub group_norm_eps: f32,
    pub decay_low_rank: usize,
    pub a_low_rank: usize,
    pub v_low_rank: usize,
    pub g_low_rank: usize,
}

impl Default for Config {
    fn default() -> Self {
        Self {
            vocab_size: 256,
            hidden_size: 256,
            num_layers: 12,
            num_heads: 4,
            head_dim: 64,
            intermediate_size: 1024,
            layer_norm_eps: 1e-5,
            group_norm_eps: 64e-5,
            decay_low_rank: 32,
            a_low_rank: 32,
            v_low_rank: 32,
            g_low_rank: 64,
        }
    }
}
/// Per-layer recurrent state.
#[derive(Clone)]
pub struct LayerState {
    /// Previous attention input, for token shift.
    pub att_x_prev: Tensor1D,
    /// WKV state matrix (num_heads * head_dim * head_dim).
    pub att_state: Tensor1D,
    /// Previous FFN input, for token shift.
    pub ffn_x_prev: Tensor1D,
}

impl LayerState {
    fn new(cfg: &Config) -> Self {
        let state_size = cfg.num_heads * cfg.head_dim * cfg.head_dim;
        Self {
            att_x_prev: Tensor1D::zeros(cfg.hidden_size),
            att_state: Tensor1D::zeros(state_size),
            ffn_x_prev: Tensor1D::zeros(cfg.hidden_size),
        }
    }
}

/// Full recurrent state for one sequence.
#[derive(Clone)]
pub struct State {
    pub layers: Vec<LayerState>,
    /// Value vector captured at layer 0 and re-mixed into deeper layers.
    pub v_first: Tensor1D,
    pub v_first_set: bool,
}

impl State {
    pub fn new(cfg: &Config) -> Self {
        Self {
            layers: (0..cfg.num_layers).map(|_| LayerState::new(cfg)).collect(),
            v_first: Tensor1D::zeros(cfg.hidden_size),
            v_first_set: false,
        }
    }

    pub fn reset(&mut self) {
        self.v_first_set = false;
        self.v_first.zero();
        for layer in &mut self.layers {
            layer.att_x_prev.zero();
            layer.att_state.zero();
            layer.ffn_x_prev.zero();
        }
    }
}

struct AttentionWeights {
    // Token-shift mixing coefficients, one per projection branch.
    x_r: Tensor1D,
    x_w: Tensor1D,
    x_k: Tensor1D,
    x_v: Tensor1D,
    x_a: Tensor1D,
    x_g: Tensor1D,

    // r/k/v projection matrices, fused into one contiguous buffer.
    rkv_proj: Tensor1D,

    // Output projection.
    o_proj: Tensor1D,

    // Decay LoRA.
    w1: Tensor1D,
    w2: Tensor1D,
    w0: Tensor1D,

    // In-context learning-rate ("a") LoRA.
    a1: Tensor1D,
    a2: Tensor1D,
    a0: Tensor1D,

    // Value-residual LoRA (absent on layer 0).
    v1: Option<Tensor1D>,
    v2: Option<Tensor1D>,
    v0: Option<Tensor1D>,

    // Output-gate LoRA.
    g1: Tensor1D,
    g2: Tensor1D,

    // Per-channel key scalings and per-head bonus weights.
    k_k: Tensor1D,
    k_a: Tensor1D,
    r_k: Tensor1D,

    // Per-head group norm applied to the WKV output.
    g_norm_w: Tensor1D,
    g_norm_b: Tensor1D,
}

struct FfnWeights {
    x_k: Tensor1D,
    key_w: Tensor1D,
    value_w: Tensor1D,
}

struct BlockWeights {
    // Extra input norm, present only on layer 0.
    pre_norm_w: Option<Tensor1D>,
    pre_norm_b: Option<Tensor1D>,

    attn_norm_w: Tensor1D,
    attn_norm_b: Tensor1D,

    ffn_norm_w: Tensor1D,
    ffn_norm_b: Tensor1D,

    attn: AttentionWeights,
    ffn: FfnWeights,
}

/// The full model: embeddings, per-layer weights, and the output head.
pub struct Model {
    cfg: Config,

    embeddings: Tensor1D,

    ln_out_w: Tensor1D,
    ln_out_b: Tensor1D,

    lm_head: Tensor1D,

    blocks: Vec<BlockWeights>,
}

/// Pre-allocated working buffers, so `forward` performs no allocation.
pub struct ScratchBuffers {
    x: Tensor1D,
    x_normed: Tensor1D,
    xr: Tensor1D,
    xw: Tensor1D,
    xk: Tensor1D,
    xv: Tensor1D,
    xa: Tensor1D,
    xg: Tensor1D,
    r: Tensor1D,
    k: Tensor1D,
    v: Tensor1D,
    w_lora_tmp: Tensor1D,
    w_decay: Tensor1D,
    a: Tensor1D,
    g: Tensor1D,
    kk: Tensor1D,
    y: Tensor1D,
    att_out: Tensor1D,
    ffn_k: Tensor1D,
    ffn_out: Tensor1D,
    logits: Tensor1D,
}

impl ScratchBuffers {
    pub fn new(cfg: &Config) -> Self {
        let c = cfg.hidden_size;
        let i = cfg.intermediate_size;
        let v = cfg.vocab_size;
        let d_w = cfg.decay_low_rank;

        Self {
            x: Tensor1D::zeros(c),
            x_normed: Tensor1D::zeros(c),
            xr: Tensor1D::zeros(c),
            xw: Tensor1D::zeros(c),
            xk: Tensor1D::zeros(c),
            xv: Tensor1D::zeros(c),
            xa: Tensor1D::zeros(c),
            xg: Tensor1D::zeros(c),
            r: Tensor1D::zeros(c),
            k: Tensor1D::zeros(c),
            v: Tensor1D::zeros(c),
            // Shared LoRA scratch; it is reused for the w, a, v, and g LoRAs,
            // so it must hold the largest of the four low ranks.
            w_lora_tmp: Tensor1D::zeros(
                d_w.max(cfg.a_low_rank)
                    .max(cfg.v_low_rank)
                    .max(cfg.g_low_rank),
            ),
            w_decay: Tensor1D::zeros(c),
            a: Tensor1D::zeros(c),
            g: Tensor1D::zeros(c),
            kk: Tensor1D::zeros(c),
            y: Tensor1D::zeros(c),
            att_out: Tensor1D::zeros(c),
            ffn_k: Tensor1D::zeros(i),
            ffn_out: Tensor1D::zeros(c),
            logits: Tensor1D::zeros(v),
        }
    }
}

impl Model {
    fn tensor_from(weights: &Weights, name: &str) -> Result<Tensor1D> {
        Ok(Tensor1D::from_vec(weights.require(name)?.data().to_vec()))
    }

    fn optional_tensor_from(weights: &Weights, name: &str) -> Option<Tensor1D> {
        weights
            .get(name)
            .map(|tensor| Tensor1D::from_vec(tensor.data().to_vec()))
    }

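    /// Load a model from a weights file, inferring all hyperparameters from
    /// tensor shapes.
    ///
    /// A minimal usage sketch (the file name is a placeholder; the on-disk
    /// format is whatever `Weights::load` accepts):
    ///
    /// ```ignore
    /// let model = Model::load("model.safetensors")?;
    /// let mut scratch = ScratchBuffers::new(model.config());
    /// let mut state = model.new_state();
    /// let logits = model.forward(&mut scratch, 0, &mut state);
    /// ```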
    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
        let weights = Weights::load(path.as_ref()).context("Failed to load model weights")?;

        let emb = weights.require("model.embeddings.weight")?;
        let vocab_size = emb.shape()[0];
        let hidden_size = emb.shape()[1];

        // Head dimension is hard-coded to 64 for this model family.
        let head_dim = 64;
        let num_heads = hidden_size / head_dim;

        // Count layers by probing for per-layer attention weights.
        let mut num_layers = 0;
        while weights
            .get(&format!("model.layers.{}.attn.r_proj.weight", num_layers))
            .is_some()
        {
            num_layers += 1;
        }

        let ffn_key = weights.require("model.layers.0.ffn.key.weight")?;
        let intermediate_size = ffn_key.shape()[0];

        let w1 = weights.require("model.layers.0.attn.w_lora.lora.0.weight")?;
        let decay_low_rank = w1.shape()[0];

        let a1 = weights.require("model.layers.0.attn.a_lora.lora.0.weight")?;
        let a_low_rank = a1.shape()[0];

        let g1 = weights.require("model.layers.0.attn.g_lora.lora.0.weight")?;
        let g_low_rank = g1.shape()[0];

        // Layer 0 carries no v_lora, so probe layer 1; fall back to 32.
        let v_low_rank = if num_layers > 1 {
            if let Some(v1) = weights.get("model.layers.1.attn.v_lora.lora.0.weight") {
                v1.shape()[0]
            } else {
                32
            }
        } else {
            32
        };

        let cfg = Config {
            vocab_size,
            hidden_size,
            num_layers,
            num_heads,
            head_dim,
            intermediate_size,
            layer_norm_eps: 1e-5,
            group_norm_eps: 64e-5,
            decay_low_rank,
            a_low_rank,
            v_low_rank,
            g_low_rank,
        };

        let embeddings = Self::tensor_from(&weights, "model.embeddings.weight")?;

        let ln_out_w = Self::tensor_from(&weights, "model.norm.weight")?;
        let ln_out_b = Self::tensor_from(&weights, "model.norm.bias")?;

        let lm_head = Self::tensor_from(&weights, "lm_head.weight")?;

        let mut blocks = Vec::with_capacity(num_layers);
        for i in 0..num_layers {
            let prefix = format!("model.layers.{}", i);

            // Only layer 0 has a pre-norm.
            let (pre_norm_w, pre_norm_b) = if i == 0 {
                (
                    Some(Self::tensor_from(
                        &weights,
                        &format!("{}.pre_norm.weight", prefix),
                    )?),
                    Some(Self::tensor_from(
                        &weights,
                        &format!("{}.pre_norm.bias", prefix),
                    )?),
                )
            } else {
                (None, None)
            };

            let attn_norm_w = Self::tensor_from(&weights, &format!("{}.attn_norm.weight", prefix))?;
            let attn_norm_b = Self::tensor_from(&weights, &format!("{}.attn_norm.bias", prefix))?;
            let ffn_norm_w = Self::tensor_from(&weights, &format!("{}.ffn_norm.weight", prefix))?;
            let ffn_norm_b = Self::tensor_from(&weights, &format!("{}.ffn_norm.bias", prefix))?;

            // Fuse the r/k/v projection matrices into one contiguous buffer
            // so the three gemvs in the attention path read adjacent memory.
            let r_proj_data = weights
                .require(&format!("{}.attn.r_proj.weight", prefix))?
                .data();
            let k_proj_data = weights
                .require(&format!("{}.attn.k_proj.weight", prefix))?
                .data();
            let v_proj_data = weights
                .require(&format!("{}.attn.v_proj.weight", prefix))?
                .data();

            let proj_size = hidden_size * hidden_size;
            let mut rkv_proj = Tensor1D::zeros(3 * proj_size);
            rkv_proj.as_mut_slice()[0..proj_size].copy_from_slice(r_proj_data);
            rkv_proj.as_mut_slice()[proj_size..2 * proj_size].copy_from_slice(k_proj_data);
            rkv_proj.as_mut_slice()[2 * proj_size..3 * proj_size].copy_from_slice(v_proj_data);

            let attn = AttentionWeights {
                x_r: Self::tensor_from(&weights, &format!("{}.attn.x_r", prefix))?,
                x_w: Self::tensor_from(&weights, &format!("{}.attn.x_w", prefix))?,
                x_k: Self::tensor_from(&weights, &format!("{}.attn.x_k", prefix))?,
                x_v: Self::tensor_from(&weights, &format!("{}.attn.x_v", prefix))?,
                x_a: Self::tensor_from(&weights, &format!("{}.attn.x_a", prefix))?,
                x_g: Self::tensor_from(&weights, &format!("{}.attn.x_g", prefix))?,

                rkv_proj,
                o_proj: Self::tensor_from(&weights, &format!("{}.attn.o_proj.weight", prefix))?,

                w1: Self::tensor_from(&weights, &format!("{}.attn.w_lora.lora.0.weight", prefix))?,
                w2: Self::tensor_from(&weights, &format!("{}.attn.w_lora.lora.2.weight", prefix))?,
                w0: Self::tensor_from(&weights, &format!("{}.attn.w_lora.lora.2.bias", prefix))?,

                a1: Self::tensor_from(&weights, &format!("{}.attn.a_lora.lora.0.weight", prefix))?,
                a2: Self::tensor_from(&weights, &format!("{}.attn.a_lora.lora.2.weight", prefix))?,
                a0: Self::tensor_from(&weights, &format!("{}.attn.a_lora.lora.2.bias", prefix))?,

                v1: Self::optional_tensor_from(
                    &weights,
                    &format!("{}.attn.v_lora.lora.0.weight", prefix),
                ),
                v2: Self::optional_tensor_from(
                    &weights,
                    &format!("{}.attn.v_lora.lora.2.weight", prefix),
                ),
                v0: Self::optional_tensor_from(
                    &weights,
                    &format!("{}.attn.v_lora.lora.2.bias", prefix),
                ),

                g1: Self::tensor_from(&weights, &format!("{}.attn.g_lora.lora.0.weight", prefix))?,
                g2: Self::tensor_from(&weights, &format!("{}.attn.g_lora.lora.2.weight", prefix))?,

                k_k: Self::tensor_from(&weights, &format!("{}.attn.k_k", prefix))?,
                k_a: Self::tensor_from(&weights, &format!("{}.attn.k_a", prefix))?,
                r_k: Self::tensor_from(&weights, &format!("{}.attn.r_k", prefix))?,

                g_norm_w: Self::tensor_from(&weights, &format!("{}.attn.g_norm.weight", prefix))?,
                g_norm_b: Self::tensor_from(&weights, &format!("{}.attn.g_norm.bias", prefix))?,
            };

            let ffn = FfnWeights {
                x_k: Self::tensor_from(&weights, &format!("{}.ffn.x_k", prefix))?,
                key_w: Self::tensor_from(&weights, &format!("{}.ffn.key.weight", prefix))?,
                value_w: Self::tensor_from(&weights, &format!("{}.ffn.value.weight", prefix))?,
            };

            blocks.push(BlockWeights {
                pre_norm_w,
                pre_norm_b,
                attn_norm_w,
                attn_norm_b,
                ffn_norm_w,
                ffn_norm_b,
                attn,
                ffn,
            });
        }

        Ok(Self {
            cfg,
            embeddings,
            ln_out_w,
            ln_out_b,
            lm_head,
            blocks,
        })
    }

    pub fn config(&self) -> &Config {
        &self.cfg
    }

    pub fn new_state(&self) -> State {
        State::new(&self.cfg)
    }

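    /// Run a single token through the model, returning the logits slice
    /// (length `vocab_size`, borrowed from `scratch`).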
    #[inline(never)]
    pub fn forward<'a>(
        &'a self,
        scratch: &'a mut ScratchBuffers,
        token: u32,
        state: &mut State,
    ) -> &'a [f32] {
        let mut sink = NullProfiler;
        self.forward_with_sink(scratch, token, state, &mut sink)
    }

    /// Same as `forward`, but reports per-layer attention and FFN timings
    /// to the given profiler.
    #[inline(never)]
    pub fn forward_with_profiler<'a, S: ProfilerSink>(
        &'a self,
        scratch: &'a mut ScratchBuffers,
        token: u32,
        state: &mut State,
        profiler: &mut S,
    ) -> &'a [f32] {
        self.forward_with_sink(scratch, token, state, profiler)
    }

    #[inline(never)]
    fn forward_with_sink<'a, S: ProfilerSink>(
        &'a self,
        scratch: &'a mut ScratchBuffers,
        token: u32,
        state: &mut State,
        profiler: &mut S,
    ) -> &'a [f32] {
        let c = self.cfg.hidden_size;
        let _h = self.cfg.num_heads;
        let _n = self.cfg.head_dim;
        let num_layers = self.cfg.num_layers;

        // Embedding lookup: copy the token's row into the residual buffer.
        let emb_offset = token as usize * c;
        let emb_slice = &self.embeddings.as_slice()[emb_offset..emb_offset + c];
        scratch.x.as_mut_slice().copy_from_slice(emb_slice);

        profiler.begin_token();

        unsafe {
            for layer_idx in 0..num_layers {
                // Layer 0 applies an extra pre-norm to the raw embedding.
                if let (Some(w), Some(b)) = (
                    &self.blocks[layer_idx].pre_norm_w,
                    &self.blocks[layer_idx].pre_norm_b,
                ) {
                    kernel::layer_norm_avx(
                        scratch.x.as_ptr(),
                        w.as_ptr(),
                        b.as_ptr(),
                        scratch.x.as_mut_ptr(),
                        c,
                        self.cfg.layer_norm_eps,
                    );
                }

                kernel::layer_norm_avx(
                    scratch.x.as_ptr(),
                    self.blocks[layer_idx].attn_norm_w.as_ptr(),
                    self.blocks[layer_idx].attn_norm_b.as_ptr(),
                    scratch.x_normed.as_mut_ptr(),
                    c,
                    self.cfg.layer_norm_eps,
                );

                let attn_start = Instant::now();
                self.attention_forward_impl(scratch, layer_idx, state);
                profiler.record_attention(layer_idx, attn_start.elapsed());

                // Residual connection around attention.
                kernel::add_avx(
                    scratch.x.as_ptr(),
                    scratch.att_out.as_ptr(),
                    scratch.x.as_mut_ptr(),
                    c,
                );

                kernel::layer_norm_avx(
                    scratch.x.as_ptr(),
                    self.blocks[layer_idx].ffn_norm_w.as_ptr(),
                    self.blocks[layer_idx].ffn_norm_b.as_ptr(),
                    scratch.x_normed.as_mut_ptr(),
                    c,
                    self.cfg.layer_norm_eps,
                );

                let ffn_start = Instant::now();
                self.ffn_forward_impl(scratch, layer_idx, &mut state.layers[layer_idx]);
                profiler.record_ffn(layer_idx, ffn_start.elapsed());

                // Residual connection around the FFN.
                kernel::add_avx(
                    scratch.x.as_ptr(),
                    scratch.ffn_out.as_ptr(),
                    scratch.x.as_mut_ptr(),
                    c,
                );
            }

            kernel::layer_norm_avx(
                scratch.x.as_ptr(),
                self.ln_out_w.as_ptr(),
                self.ln_out_b.as_ptr(),
                scratch.x_normed.as_mut_ptr(),
                c,
                self.cfg.layer_norm_eps,
            );

            kernel::gemv_avx(
                self.lm_head.as_ptr(),
                scratch.x_normed.as_ptr(),
                scratch.logits.as_mut_ptr(),
                self.cfg.vocab_size,
                c,
            );
        }

        scratch.logits.as_slice()
    }

    #[inline(always)]
    unsafe fn attention_forward_impl(
        &self,
        scratch: &mut ScratchBuffers,
        layer_idx: usize,
        state: &mut State,
    ) {
        let attn = &self.blocks[layer_idx].attn;
        let layer_state = &mut state.layers[layer_idx];
        let c = self.cfg.hidden_size;
        let h = self.cfg.num_heads;
        let n = self.cfg.head_dim;
        let d_w = self.cfg.decay_low_rank;
        let d_a = self.cfg.a_low_rank;
        let d_g = self.cfg.g_low_rank;

        // Token shift: blend the current and previous inputs with six
        // per-channel coefficients, one blend per projection branch.
        kernel::token_shift_multi6_avx(
            scratch.x_normed.as_ptr(),
            layer_state.att_x_prev.as_ptr(),
            attn.x_r.as_ptr(),
            attn.x_w.as_ptr(),
            attn.x_k.as_ptr(),
            attn.x_v.as_ptr(),
            attn.x_a.as_ptr(),
            attn.x_g.as_ptr(),
            scratch.xr.as_mut_ptr(),
            scratch.xw.as_mut_ptr(),
            scratch.xk.as_mut_ptr(),
            scratch.xv.as_mut_ptr(),
            scratch.xa.as_mut_ptr(),
            scratch.xg.as_mut_ptr(),
            c,
        );

        kernel::copy(
            scratch.x_normed.as_ptr(),
            layer_state.att_x_prev.as_mut_ptr(),
            c,
        );

        // r/k/v projections from the fused weight buffer.
        let proj_size = c * c;
        kernel::gemv_avx(
            attn.rkv_proj.as_ptr(),
            scratch.xr.as_ptr(),
            scratch.r.as_mut_ptr(),
            c,
            c,
        );
        kernel::gemv_avx(
            attn.rkv_proj.as_ptr().add(proj_size),
            scratch.xk.as_ptr(),
            scratch.k.as_mut_ptr(),
            c,
            c,
        );
        kernel::gemv_avx(
            attn.rkv_proj.as_ptr().add(2 * proj_size),
            scratch.xv.as_ptr(),
            scratch.v.as_mut_ptr(),
            c,
            c,
        );

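        // Decay LoRA. Taken together, the kernel calls below compute, per
        // channel,
        //   w_decay = exp(-(1 / sqrt(e)) * sigmoid(w0 + w2 @ tanh(w1 @ xw)))
        // (assuming exp_neg_scaled_inplace(x, s, n) evaluates exp(-s * x) in
        // place), so every decay lands in (exp(-1 / sqrt(e)), 1).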
        kernel::gemv_avx(
            attn.w1.as_ptr(),
            scratch.xw.as_ptr(),
            scratch.w_lora_tmp.as_mut_ptr(),
            d_w,
            c,
        );
        kernel::tanh_avx(
            scratch.w_lora_tmp.as_ptr(),
            scratch.w_lora_tmp.as_mut_ptr(),
            d_w,
        );
        kernel::gemv_avx(
            attn.w2.as_ptr(),
            scratch.w_lora_tmp.as_ptr(),
            scratch.w_decay.as_mut_ptr(),
            c,
            d_w,
        );
        kernel::add_avx(
            scratch.w_decay.as_ptr(),
            attn.w0.as_ptr(),
            scratch.w_decay.as_mut_ptr(),
            c,
        );
        let inv_sqrt_e = 1.0 / std::f32::consts::E.sqrt();
        kernel::sigmoid_avx(scratch.w_decay.as_ptr(), scratch.w_decay.as_mut_ptr(), c);
        kernel::exp_neg_scaled_inplace(scratch.w_decay.as_mut_ptr(), inv_sqrt_e, c);

        // In-context learning rate: a = sigmoid(a0 + a2 @ (a1 @ xa)).
        kernel::gemv_avx(
            attn.a1.as_ptr(),
            scratch.xa.as_ptr(),
            scratch.w_lora_tmp.as_mut_ptr(),
            d_a,
            c,
        );
        kernel::gemv_avx(
            attn.a2.as_ptr(),
            scratch.w_lora_tmp.as_ptr(),
            scratch.a.as_mut_ptr(),
            c,
            d_a,
        );
        kernel::add_avx(
            scratch.a.as_ptr(),
            attn.a0.as_ptr(),
            scratch.a.as_mut_ptr(),
            c,
        );
        kernel::sigmoid_avx(scratch.a.as_ptr(), scratch.a.as_mut_ptr(), c);

        // Output gate: g = g2 @ sigmoid(g1 @ xg).
        kernel::gemv_avx(
            attn.g1.as_ptr(),
            scratch.xg.as_ptr(),
            scratch.w_lora_tmp.as_mut_ptr(),
            d_g,
            c,
        );
        kernel::sigmoid_avx(
            scratch.w_lora_tmp.as_ptr(),
            scratch.w_lora_tmp.as_mut_ptr(),
            d_g,
        );
        kernel::gemv_avx(
            attn.g2.as_ptr(),
            scratch.w_lora_tmp.as_ptr(),
            scratch.g.as_mut_ptr(),
            c,
            d_g,
        );

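        // Value residual: layer 0 records its value vector; deeper layers
        // blend it back in through a learned sigmoid gate, i.e.
        //   v += (v_first - v) * sigmoid(v0 + v2 @ (v1 @ xv)).
        // att_out is free at this point and doubles as gate scratch.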
        if layer_idx == 0 {
            state.v_first.copy_from(&scratch.v);
            state.v_first_set = true;
        } else if state.v_first_set {
            if let (Some(v1), Some(v2), Some(v0)) = (&attn.v1, &attn.v2, &attn.v0) {
                let d_v = self.cfg.v_low_rank;
                kernel::gemv_avx(
                    v1.as_ptr(),
                    scratch.xv.as_ptr(),
                    scratch.w_lora_tmp.as_mut_ptr(),
                    d_v,
                    c,
                );
                kernel::gemv_avx(
                    v2.as_ptr(),
                    scratch.w_lora_tmp.as_ptr(),
                    scratch.att_out.as_mut_ptr(),
                    c,
                    d_v,
                );
                kernel::add_avx(
                    scratch.att_out.as_ptr(),
                    v0.as_ptr(),
                    scratch.att_out.as_mut_ptr(),
                    c,
                );
                kernel::sigmoid_avx(scratch.att_out.as_ptr(), scratch.att_out.as_mut_ptr(), c);
                for i in 0..c {
                    let nu = scratch.att_out[i];
                    scratch.v[i] += (state.v_first[i] - scratch.v[i]) * nu;
                }
            }
        }

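        // Removal key: kk = k * k_k, L2-normalized per head, consumed by the
        // WKV state update below.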
        kernel::mul_avx(
            scratch.k.as_ptr(),
            attn.k_k.as_ptr(),
            scratch.kk.as_mut_ptr(),
            c,
        );
        for head in 0..h {
            let offset = head * n;
            kernel::l2_normalize_avx(
                scratch.kk.as_ptr().add(offset),
                scratch.kk.as_mut_ptr().add(offset),
                n,
                1e-12,
            );
        }

        // Rescale k per channel: k *= 1 + (a - 1) * k_a.
        for i in 0..c {
            let scale = 1.0 + (scratch.a[i] - 1.0) * attn.k_a[i];
            scratch.k[i] *= scale;
        }

        // RWKV-7 WKV recurrence: updates the per-head state matrix in place
        // and writes the attention readout into y.
        kernel::rwkv7_wkv_update_avx(
            layer_state.att_state.as_mut_ptr(),
            scratch.w_decay.as_ptr(),
            scratch.k.as_ptr(),
            scratch.v.as_ptr(),
            scratch.kk.as_ptr(),
            scratch.a.as_ptr(),
            scratch.r.as_ptr(),
            scratch.y.as_mut_ptr(),
            h,
            n,
        );

        kernel::group_norm_avx(
            scratch.y.as_ptr(),
            attn.g_norm_w.as_ptr(),
            attn.g_norm_b.as_ptr(),
            scratch.y.as_mut_ptr(),
            h,
            n,
            self.cfg.group_norm_eps,
        );

        // Per-head bonus: y += dot(r * k, r_k) * v.
        for head in 0..h {
            let offset = head * n;
            let mut alpha = 0.0f32;
            for j in 0..n {
                alpha += scratch.r[offset + j] * scratch.k[offset + j] * attn.r_k[offset + j];
            }
            for j in 0..n {
                scratch.y[offset + j] += alpha * scratch.v[offset + j];
            }
        }

        // Apply the output gate, then project back to the residual stream.
        kernel::mul_avx(
            scratch.y.as_ptr(),
            scratch.g.as_ptr(),
            scratch.y.as_mut_ptr(),
            c,
        );

        kernel::gemv_avx(
            attn.o_proj.as_ptr(),
            scratch.y.as_ptr(),
            scratch.att_out.as_mut_ptr(),
            c,
            c,
        );
    }

    #[inline(always)]
    unsafe fn ffn_forward_impl(
        &self,
        scratch: &mut ScratchBuffers,
        layer_idx: usize,
        layer_state: &mut LayerState,
    ) {
        let ffn = &self.blocks[layer_idx].ffn;
        let c = self.cfg.hidden_size;
        let i = self.cfg.intermediate_size;

        // Single-coefficient token shift, then remember this input.
        kernel::token_shift_avx(
            scratch.x_normed.as_ptr(),
            layer_state.ffn_x_prev.as_ptr(),
            ffn.x_k.as_ptr(),
            scratch.xk.as_mut_ptr(),
            c,
        );

        kernel::copy(
            scratch.x_normed.as_ptr(),
            layer_state.ffn_x_prev.as_mut_ptr(),
            c,
        );

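        // Squared-ReLU MLP: ffn_out = value_w @ relu(key_w @ xk)^2.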
        kernel::gemv_avx(
            ffn.key_w.as_ptr(),
            scratch.xk.as_ptr(),
            scratch.ffn_k.as_mut_ptr(),
            i,
            c,
        );
        kernel::relu_squared_avx(scratch.ffn_k.as_ptr(), scratch.ffn_k.as_mut_ptr(), i);

        kernel::gemv_avx(
            ffn.value_w.as_ptr(),
            scratch.ffn_k.as_ptr(),
            scratch.ffn_out.as_mut_ptr(),
            c,
            i,
        );
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_default() {
        let cfg = Config::default();
        assert_eq!(cfg.vocab_size, 256);
        assert_eq!(cfg.hidden_size, 256);
        assert_eq!(cfg.num_layers, 12);
        assert_eq!(cfg.num_heads, 4);
        assert_eq!(cfg.head_dim, 64);
    }
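
    // An added sketch: exercises State::new and State::reset with the default
    // config only; no weights file or forward pass is involved.
    #[test]
    fn test_state_reset() {
        let cfg = Config::default();
        let mut state = State::new(&cfg);
        state.layers[0].att_x_prev.as_mut_slice()[0] = 1.0;
        state.v_first.as_mut_slice()[0] = 2.0;
        state.v_first_set = true;

        state.reset();

        assert_eq!(state.layers[0].att_x_prev.as_slice()[0], 0.0);
        assert_eq!(state.v_first.as_slice()[0], 0.0);
        assert!(!state.v_first_set);
    }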
}