// rwkvzip/rwkv7/model.rs

//! RWKV7 model implementation with SIMD-optimized inference.
//!
//! This is a high-performance implementation specifically for x86_64 CPUs.
//! Single-token inference is the primary use case (streaming compression).
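//!
//! # Example
//!
//! A minimal single-token streaming loop (illustrative sketch; assumes the
//! crate is named `rwkvzip` and that a checkpoint exists at the given path):
//!
//! ```ignore
//! use rwkvzip::rwkv7::model::{Model, ScratchBuffers};
//!
//! let model = Model::load("model.safetensors").expect("load weights");
//! let mut state = model.new_state();
//! let mut scratch = ScratchBuffers::new(model.config());
//!
//! // Feed one token at a time; each call returns next-token logits.
//! for &token in &[72u32, 101, 108, 108, 111] {
//!     let logits = model.forward(&mut scratch, token, &mut state);
//!     assert_eq!(logits.len(), model.config().vocab_size);
//! }
//! ```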

use anyhow::{Context, Result};
use std::path::Path;
use std::time::Instant;

use super::kernel;
use super::profiling::{NullProfiler, ProfilerSink};
use super::tensor::Tensor1D;
use super::weights::Weights;

/// Model configuration.
#[derive(Debug, Clone)]
pub struct Config {
    pub vocab_size: usize,
    pub hidden_size: usize,
    pub num_layers: usize,
    pub num_heads: usize,
    pub head_dim: usize,
    pub intermediate_size: usize,
    pub layer_norm_eps: f32,
    pub group_norm_eps: f32, // 64e-5 per reference

    // Low-rank dimensions
    pub decay_low_rank: usize, // w_lora
    pub a_low_rank: usize,
    pub v_low_rank: usize,
    pub g_low_rank: usize,
}

impl Default for Config {
    fn default() -> Self {
        Self {
            vocab_size: 256,
            hidden_size: 256,
            num_layers: 12,
            num_heads: 4, // 256 / 64
            head_dim: 64,
            intermediate_size: 1024,
            layer_norm_eps: 1e-5,
            group_norm_eps: 64e-5,
            decay_low_rank: 32,
            a_low_rank: 32,
            v_low_rank: 32,
            g_low_rank: 64,
        }
    }
}

/// Per-layer state for RWKV7.
#[derive(Clone)]
pub struct LayerState {
    /// Previous token's normalized hidden state, used for attention time-shift (hidden_size,)
    pub att_x_prev: Tensor1D,
    /// Attention state matrix (num_heads, head_dim, head_dim) = (H, N, N)
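    /// For the default config: 4 * 64 * 64 = 16,384 f32 (64 KiB) per layer.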
    pub att_state: Tensor1D, // Flat for SIMD access
    /// Previous token's normalized hidden state, used for FFN time-shift (hidden_size,)
    pub ffn_x_prev: Tensor1D,
}

impl LayerState {
    fn new(cfg: &Config) -> Self {
        let state_size = cfg.num_heads * cfg.head_dim * cfg.head_dim;
        Self {
            att_x_prev: Tensor1D::zeros(cfg.hidden_size),
            att_state: Tensor1D::zeros(state_size),
            ffn_x_prev: Tensor1D::zeros(cfg.hidden_size),
        }
    }
}

/// Full model state.
#[derive(Clone)]
pub struct State {
    pub layers: Vec<LayerState>,
    /// First layer's value output (for residual connection) - pre-allocated
    pub v_first: Tensor1D,
    /// Flag to indicate if v_first has been set
    pub v_first_set: bool,
}

impl State {
    pub fn new(cfg: &Config) -> Self {
        Self {
            layers: (0..cfg.num_layers).map(|_| LayerState::new(cfg)).collect(),
            v_first: Tensor1D::zeros(cfg.hidden_size),
            v_first_set: false,
        }
    }

    pub fn reset(&mut self) {
        self.v_first_set = false;
        self.v_first.zero();
        for layer in &mut self.layers {
            layer.att_x_prev.zero();
            layer.att_state.zero();
            layer.ffn_x_prev.zero();
        }
    }
}

/// Weights for a single attention layer.
struct AttentionWeights {
    // Token shift mixing factors
    x_r: Tensor1D,
    x_w: Tensor1D,
    x_k: Tensor1D,
    x_v: Tensor1D,
    x_a: Tensor1D,
    x_g: Tensor1D,

    // r/k/v projections packed contiguously for better cache behavior
    // Layout: [r_proj (C*C), k_proj (C*C), v_proj (C*C)]
    rkv_proj: Tensor1D,

    // Output projection (stored transposed for efficient gemv)
    o_proj: Tensor1D,

    // Low-rank W: w = tanh(x @ w1) @ w2 + w0
    w1: Tensor1D, // (C, D_w)
    w2: Tensor1D, // (D_w, C)
    w0: Tensor1D, // (C,)

    // Low-rank A: a = sigmoid(x @ a1 @ a2 + a0)
    a1: Tensor1D, // (C, D_a)
    a2: Tensor1D, // (D_a, C)
    a0: Tensor1D, // (C,)

    // Low-rank V (layers > 0): nu = sigmoid(x @ v1 @ v2 + v0)
    v1: Option<Tensor1D>, // (C, D_v)
    v2: Option<Tensor1D>, // (D_v, C)
    v0: Option<Tensor1D>, // (C,)

    // Low-rank G: g = sigmoid(x @ g1) @ g2
    g1: Tensor1D, // (C, D_g)
    g2: Tensor1D, // (D_g, C)

    // Key scaling
    k_k: Tensor1D, // (C,)
    k_a: Tensor1D, // (C,)
    r_k: Tensor1D, // (H, N)

    // Group norm for output
    g_norm_w: Tensor1D, // (C,)
    g_norm_b: Tensor1D, // (C,)
}

/// Weights for a single FFN layer.
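/// Computes ffn(x) = relu(x_shifted @ key_w.T)^2 @ value_w.T (squared-ReLU MLP);
/// see `ffn_forward_impl` below.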
struct FfnWeights {
    x_k: Tensor1D,     // (C,) time shift mix
    key_w: Tensor1D,   // (C, I) -> relu(x @ W)^2
    value_w: Tensor1D, // (I, C)
}

/// Weights for a single block.
struct BlockWeights {
    // Pre-norm (layer 0 only)
    pre_norm_w: Option<Tensor1D>,
    pre_norm_b: Option<Tensor1D>,

    // Attention norm
    attn_norm_w: Tensor1D,
    attn_norm_b: Tensor1D,

    // FFN norm
    ffn_norm_w: Tensor1D,
    ffn_norm_b: Tensor1D,

    attn: AttentionWeights,
    ffn: FfnWeights,
}

/// RWKV7 model.
pub struct Model {
    cfg: Config,

    // Embeddings (vocab_size, hidden_size)
    embeddings: Tensor1D,

    // Output norm
    ln_out_w: Tensor1D,
    ln_out_b: Tensor1D,

    // LM head (vocab_size, hidden_size)
    lm_head: Tensor1D,

    // Layers
    blocks: Vec<BlockWeights>,
}

/// Pre-allocated scratch buffers to avoid allocations in the hot path.
pub struct ScratchBuffers {
    x: Tensor1D,          // Current hidden state
    x_normed: Tensor1D,   // After layer norm
    xr: Tensor1D,         // Token-shifted for r
    xw: Tensor1D,         // Token-shifted for w
    xk: Tensor1D,         // Token-shifted for k
    xv: Tensor1D,         // Token-shifted for v
    xa: Tensor1D,         // Token-shifted for a
    xg: Tensor1D,         // Token-shifted for g
    r: Tensor1D,          // Receptance
    k: Tensor1D,          // Key
    v: Tensor1D,          // Value
    w_lora_tmp: Tensor1D, // Low-rank temp
    w_decay: Tensor1D,    // Decay factor
    a: Tensor1D,          // Gate a
    g: Tensor1D,          // Gate g
    kk: Tensor1D,         // Normalized key
    y: Tensor1D,          // WKV output
    att_out: Tensor1D,    // Attention output
    ffn_k: Tensor1D,      // FFN key
    ffn_out: Tensor1D,    // FFN output
    logits: Tensor1D,     // Output logits
}

impl ScratchBuffers {
    pub fn new(cfg: &Config) -> Self {
        let c = cfg.hidden_size;
        let i = cfg.intermediate_size;
        let v = cfg.vocab_size;
        let d_w = cfg.decay_low_rank;

        Self {
            x: Tensor1D::zeros(c),
            x_normed: Tensor1D::zeros(c),
            xr: Tensor1D::zeros(c),
            xw: Tensor1D::zeros(c),
            xk: Tensor1D::zeros(c),
            xv: Tensor1D::zeros(c),
            xa: Tensor1D::zeros(c),
            xg: Tensor1D::zeros(c),
            r: Tensor1D::zeros(c),
            k: Tensor1D::zeros(c),
            v: Tensor1D::zeros(c),
            // Shared temp for the w/a/v/g low-rank projections, sized to the largest
            w_lora_tmp: Tensor1D::zeros(
                d_w.max(cfg.a_low_rank)
                    .max(cfg.v_low_rank)
                    .max(cfg.g_low_rank),
            ),
            w_decay: Tensor1D::zeros(c),
            a: Tensor1D::zeros(c),
            g: Tensor1D::zeros(c),
            kk: Tensor1D::zeros(c),
            y: Tensor1D::zeros(c),
            att_out: Tensor1D::zeros(c),
            ffn_k: Tensor1D::zeros(i),
            ffn_out: Tensor1D::zeros(c),
            logits: Tensor1D::zeros(v),
        }
    }
}

impl Model {
    fn tensor_from(weights: &Weights, name: &str) -> Result<Tensor1D> {
        Ok(Tensor1D::from_vec(weights.require(name)?.data().to_vec()))
    }

    fn optional_tensor_from(weights: &Weights, name: &str) -> Option<Tensor1D> {
        weights
            .get(name)
            .map(|tensor| Tensor1D::from_vec(tensor.data().to_vec()))
    }

    /// Load a model from a safetensors file.
    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
        let weights = Weights::load(path.as_ref()).context("Failed to load model weights")?;

        // Infer config from weights
        let emb = weights.require("model.embeddings.weight")?;
        let vocab_size = emb.shape()[0];
        let hidden_size = emb.shape()[1];

        let num_heads = hidden_size / 64; // Assume head_dim=64
        let head_dim = 64;

        // Count layers by looking for layer weights
        let mut num_layers = 0;
        while weights
            .get(&format!("model.layers.{}.attn.r_proj.weight", num_layers))
            .is_some()
        {
            num_layers += 1;
        }

        // Get intermediate size from FFN
        let ffn_key = weights.require("model.layers.0.ffn.key.weight")?;
        let intermediate_size = ffn_key.shape()[0];

        // Get low-rank dimensions
        let w1 = weights.require("model.layers.0.attn.w_lora.lora.0.weight")?;
        let decay_low_rank = w1.shape()[0];

        let a1 = weights.require("model.layers.0.attn.a_lora.lora.0.weight")?;
        let a_low_rank = a1.shape()[0];

        let g1 = weights.require("model.layers.0.attn.g_lora.lora.0.weight")?;
        let g_low_rank = g1.shape()[0];

        // v_low_rank from layer 1 (layer 0 doesn't have it)
        let v_low_rank = if num_layers > 1 {
            if let Some(v1) = weights.get("model.layers.1.attn.v_lora.lora.0.weight") {
                v1.shape()[0]
            } else {
                32
            }
        } else {
            32
        };

        let cfg = Config {
            vocab_size,
            hidden_size,
            num_layers,
            num_heads,
            head_dim,
            intermediate_size,
            layer_norm_eps: 1e-5,
            group_norm_eps: 64e-5,
            decay_low_rank,
            a_low_rank,
            v_low_rank,
            g_low_rank,
        };

        // Load embeddings
        let embeddings = Self::tensor_from(&weights, "model.embeddings.weight")?;

        // Load output norm
        let ln_out_w = Self::tensor_from(&weights, "model.norm.weight")?;
        let ln_out_b = Self::tensor_from(&weights, "model.norm.bias")?;

        // Load LM head
        let lm_head = Self::tensor_from(&weights, "lm_head.weight")?;

        // Load blocks
        let mut blocks = Vec::with_capacity(num_layers);
        for i in 0..num_layers {
            let prefix = format!("model.layers.{}", i);

            // Pre-norm (layer 0 only)
            let (pre_norm_w, pre_norm_b) = if i == 0 {
                (
                    Some(Self::tensor_from(
                        &weights,
                        &format!("{}.pre_norm.weight", prefix),
                    )?),
                    Some(Self::tensor_from(
                        &weights,
                        &format!("{}.pre_norm.bias", prefix),
                    )?),
                )
            } else {
                (None, None)
            };

            // Norms
            let attn_norm_w = Self::tensor_from(&weights, &format!("{}.attn_norm.weight", prefix))?;
            let attn_norm_b = Self::tensor_from(&weights, &format!("{}.attn_norm.bias", prefix))?;
            let ffn_norm_w = Self::tensor_from(&weights, &format!("{}.ffn_norm.weight", prefix))?;
            let ffn_norm_b = Self::tensor_from(&weights, &format!("{}.ffn_norm.bias", prefix))?;

            // Attention weights
            // Load r/k/v projections and pack them contiguously
            let r_proj_data = weights
                .require(&format!("{}.attn.r_proj.weight", prefix))?
                .data();
            let k_proj_data = weights
                .require(&format!("{}.attn.k_proj.weight", prefix))?
                .data();
            let v_proj_data = weights
                .require(&format!("{}.attn.v_proj.weight", prefix))?
                .data();

            // Create packed RKV tensor: [r_proj, k_proj, v_proj]
            let proj_size = hidden_size * hidden_size;
            let mut rkv_proj = Tensor1D::zeros(3 * proj_size);
            rkv_proj.as_mut_slice()[0..proj_size].copy_from_slice(r_proj_data);
            rkv_proj.as_mut_slice()[proj_size..2 * proj_size].copy_from_slice(k_proj_data);
            rkv_proj.as_mut_slice()[2 * proj_size..3 * proj_size].copy_from_slice(v_proj_data);

            let attn = AttentionWeights {
                x_r: Self::tensor_from(&weights, &format!("{}.attn.x_r", prefix))?,
                x_w: Self::tensor_from(&weights, &format!("{}.attn.x_w", prefix))?,
                x_k: Self::tensor_from(&weights, &format!("{}.attn.x_k", prefix))?,
                x_v: Self::tensor_from(&weights, &format!("{}.attn.x_v", prefix))?,
                x_a: Self::tensor_from(&weights, &format!("{}.attn.x_a", prefix))?,
                x_g: Self::tensor_from(&weights, &format!("{}.attn.x_g", prefix))?,

                rkv_proj,
                o_proj: Self::tensor_from(&weights, &format!("{}.attn.o_proj.weight", prefix))?,

                w1: Self::tensor_from(&weights, &format!("{}.attn.w_lora.lora.0.weight", prefix))?,
                w2: Self::tensor_from(&weights, &format!("{}.attn.w_lora.lora.2.weight", prefix))?,
                w0: Self::tensor_from(&weights, &format!("{}.attn.w_lora.lora.2.bias", prefix))?,

                a1: Self::tensor_from(&weights, &format!("{}.attn.a_lora.lora.0.weight", prefix))?,
                a2: Self::tensor_from(&weights, &format!("{}.attn.a_lora.lora.2.weight", prefix))?,
                a0: Self::tensor_from(&weights, &format!("{}.attn.a_lora.lora.2.bias", prefix))?,

                v1: Self::optional_tensor_from(
                    &weights,
                    &format!("{}.attn.v_lora.lora.0.weight", prefix),
                ),
                v2: Self::optional_tensor_from(
                    &weights,
                    &format!("{}.attn.v_lora.lora.2.weight", prefix),
                ),
                v0: Self::optional_tensor_from(
                    &weights,
                    &format!("{}.attn.v_lora.lora.2.bias", prefix),
                ),

                g1: Self::tensor_from(&weights, &format!("{}.attn.g_lora.lora.0.weight", prefix))?,
                g2: Self::tensor_from(&weights, &format!("{}.attn.g_lora.lora.2.weight", prefix))?,

                k_k: Self::tensor_from(&weights, &format!("{}.attn.k_k", prefix))?,
                k_a: Self::tensor_from(&weights, &format!("{}.attn.k_a", prefix))?,
                r_k: Self::tensor_from(&weights, &format!("{}.attn.r_k", prefix))?,

                g_norm_w: Self::tensor_from(&weights, &format!("{}.attn.g_norm.weight", prefix))?,
                g_norm_b: Self::tensor_from(&weights, &format!("{}.attn.g_norm.bias", prefix))?,
            };

            // FFN weights
            let ffn = FfnWeights {
                x_k: Self::tensor_from(&weights, &format!("{}.ffn.x_k", prefix))?,
                key_w: Self::tensor_from(&weights, &format!("{}.ffn.key.weight", prefix))?,
                value_w: Self::tensor_from(&weights, &format!("{}.ffn.value.weight", prefix))?,
            };

            blocks.push(BlockWeights {
                pre_norm_w,
                pre_norm_b,
                attn_norm_w,
                attn_norm_b,
                ffn_norm_w,
                ffn_norm_b,
                attn,
                ffn,
            });
        }

        Ok(Self {
            cfg,
            embeddings,
            ln_out_w,
            ln_out_b,
            lm_head,
            blocks,
        })
    }

    /// Get model configuration.
    pub fn config(&self) -> &Config {
        &self.cfg
    }

    /// Create new state for this model.
    pub fn new_state(&self) -> State {
        State::new(&self.cfg)
    }

    /// Forward pass for a single token.
    /// Returns logits for next token prediction.
    #[inline(never)]
    pub fn forward<'a>(
        &'a self,
        scratch: &'a mut ScratchBuffers,
        token: u32,
        state: &mut State,
    ) -> &'a [f32] {
        let mut sink = NullProfiler;
        self.forward_with_sink(scratch, token, state, &mut sink)
    }

    /// Forward pass that records per-layer timings through a custom sink.
    #[inline(never)]
    pub fn forward_with_profiler<'a, S: ProfilerSink>(
        &'a self,
        scratch: &'a mut ScratchBuffers,
        token: u32,
        state: &mut State,
        profiler: &mut S,
    ) -> &'a [f32] {
        self.forward_with_sink(scratch, token, state, profiler)
    }

    #[inline(never)]
    fn forward_with_sink<'a, S: ProfilerSink>(
        &'a self,
        scratch: &'a mut ScratchBuffers,
        token: u32,
        state: &mut State,
        profiler: &mut S,
    ) -> &'a [f32] {
        let c = self.cfg.hidden_size;
        let _h = self.cfg.num_heads;
        let _n = self.cfg.head_dim;
        let num_layers = self.cfg.num_layers;

        // Get token embedding
        let emb_offset = token as usize * c;
        let emb_slice = &self.embeddings.as_slice()[emb_offset..emb_offset + c];
        scratch.x.as_mut_slice().copy_from_slice(emb_slice);

        profiler.begin_token();

        unsafe {
            // Process each layer (using index to avoid borrow conflicts)
            for layer_idx in 0..num_layers {
                // Pre-norm (layer 0 only)
                if let (Some(w), Some(b)) = (
                    &self.blocks[layer_idx].pre_norm_w,
                    &self.blocks[layer_idx].pre_norm_b,
                ) {
                    kernel::layer_norm_avx(
                        scratch.x.as_ptr(),
                        w.as_ptr(),
                        b.as_ptr(),
                        scratch.x.as_mut_ptr(),
                        c,
                        self.cfg.layer_norm_eps,
                    );
                }

                // Attention norm
                kernel::layer_norm_avx(
                    scratch.x.as_ptr(),
                    self.blocks[layer_idx].attn_norm_w.as_ptr(),
                    self.blocks[layer_idx].attn_norm_b.as_ptr(),
                    scratch.x_normed.as_mut_ptr(),
                    c,
                    self.cfg.layer_norm_eps,
                );

                let attn_start = Instant::now();
                self.attention_forward_impl(scratch, layer_idx, state);
                profiler.record_attention(layer_idx, attn_start.elapsed());

                // Add attention residual: x = x + att_out
                kernel::add_avx(
                    scratch.x.as_ptr(),
                    scratch.att_out.as_ptr(),
                    scratch.x.as_mut_ptr(),
                    c,
                );

                // FFN norm
                kernel::layer_norm_avx(
                    scratch.x.as_ptr(),
                    self.blocks[layer_idx].ffn_norm_w.as_ptr(),
                    self.blocks[layer_idx].ffn_norm_b.as_ptr(),
                    scratch.x_normed.as_mut_ptr(),
                    c,
                    self.cfg.layer_norm_eps,
                );

                let ffn_start = Instant::now();
                self.ffn_forward_impl(scratch, layer_idx, &mut state.layers[layer_idx]);
                profiler.record_ffn(layer_idx, ffn_start.elapsed());

                // Add FFN residual: x = x + ffn_out
                kernel::add_avx(
                    scratch.x.as_ptr(),
                    scratch.ffn_out.as_ptr(),
                    scratch.x.as_mut_ptr(),
                    c,
                );
            }

            // Output norm
            kernel::layer_norm_avx(
                scratch.x.as_ptr(),
                self.ln_out_w.as_ptr(),
                self.ln_out_b.as_ptr(),
                scratch.x_normed.as_mut_ptr(),
                c,
                self.cfg.layer_norm_eps,
            );

            // LM head: logits = x @ lm_head.T
            kernel::gemv_avx(
                self.lm_head.as_ptr(),
                scratch.x_normed.as_ptr(),
                scratch.logits.as_mut_ptr(),
                self.cfg.vocab_size,
                c,
            );
        }

        scratch.logits.as_slice()
    }

    #[inline(always)]
    unsafe fn attention_forward_impl(
        &self,
        scratch: &mut ScratchBuffers,
        layer_idx: usize,
        state: &mut State,
    ) {
        let attn = &self.blocks[layer_idx].attn;
        let layer_state = &mut state.layers[layer_idx];
        let c = self.cfg.hidden_size;
        let h = self.cfg.num_heads;
        let n = self.cfg.head_dim;
        let d_w = self.cfg.decay_low_rank;
        let d_a = self.cfg.a_low_rank;
        let d_g = self.cfg.g_low_rank;

        kernel::token_shift_multi6_avx(
            scratch.x_normed.as_ptr(),
            layer_state.att_x_prev.as_ptr(),
            attn.x_r.as_ptr(),
            attn.x_w.as_ptr(),
            attn.x_k.as_ptr(),
            attn.x_v.as_ptr(),
            attn.x_a.as_ptr(),
            attn.x_g.as_ptr(),
            scratch.xr.as_mut_ptr(),
            scratch.xw.as_mut_ptr(),
            scratch.xk.as_mut_ptr(),
            scratch.xv.as_mut_ptr(),
            scratch.xa.as_mut_ptr(),
            scratch.xg.as_mut_ptr(),
            c,
        );

        // Update prev state for next token
        kernel::copy(
            scratch.x_normed.as_ptr(),
            layer_state.att_x_prev.as_mut_ptr(),
            c,
        );

        // r/k/v projections from packed matrix (sequential for better cache)
        // Packed layout: [r_proj (C*C), k_proj (C*C), v_proj (C*C)]
        let proj_size = c * c;
        kernel::gemv_avx(
            attn.rkv_proj.as_ptr(),
            scratch.xr.as_ptr(),
            scratch.r.as_mut_ptr(),
            c,
            c,
        );
        kernel::gemv_avx(
            attn.rkv_proj.as_ptr().add(proj_size),
            scratch.xk.as_ptr(),
            scratch.k.as_mut_ptr(),
            c,
            c,
        );
        kernel::gemv_avx(
            attn.rkv_proj.as_ptr().add(2 * proj_size),
            scratch.xv.as_ptr(),
            scratch.v.as_mut_ptr(),
            c,
            c,
        );

        // w decay: w = exp(-sigmoid(tanh(xw @ w1) @ w2 + w0) / sqrt(e))
        // Step 1: tmp = xw @ w1.T (D_w output)
        kernel::gemv_avx(
            attn.w1.as_ptr(),
            scratch.xw.as_ptr(),
            scratch.w_lora_tmp.as_mut_ptr(),
            d_w,
            c,
        );
        // Step 2: tanh
        kernel::tanh_avx(
            scratch.w_lora_tmp.as_ptr(),
            scratch.w_lora_tmp.as_mut_ptr(),
            d_w,
        );
        // Step 3: tmp @ w2.T + w0
        kernel::gemv_avx(
            attn.w2.as_ptr(),
            scratch.w_lora_tmp.as_ptr(),
            scratch.w_decay.as_mut_ptr(),
            c,
            d_w,
        );
        // Add bias w0
        kernel::add_avx(
            scratch.w_decay.as_ptr(),
            attn.w0.as_ptr(),
            scratch.w_decay.as_mut_ptr(),
            c,
        );
        // Step 4: exp(-sigmoid(x) / sqrt(e))
        let inv_sqrt_e = 1.0 / std::f32::consts::E.sqrt();
        kernel::sigmoid_avx(scratch.w_decay.as_ptr(), scratch.w_decay.as_mut_ptr(), c);
        kernel::exp_neg_scaled_inplace(scratch.w_decay.as_mut_ptr(), inv_sqrt_e, c);
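        // Net effect per channel (illustrative summary of the four steps above,
        // writing lora_w(xw) = tanh(xw @ w1.T) @ w2.T + w0):
        //   w[i] = exp(-sigmoid(lora_w(xw)[i]) / sqrt(e))
        // so each decay factor lies in (exp(-1/sqrt(e)), 1), roughly (0.545, 1.0).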

        // a = sigmoid(xa @ a1.T @ a2.T + a0)
        kernel::gemv_avx(
            attn.a1.as_ptr(),
            scratch.xa.as_ptr(),
            scratch.w_lora_tmp.as_mut_ptr(),
            d_a,
            c,
        );
        kernel::gemv_avx(
            attn.a2.as_ptr(),
            scratch.w_lora_tmp.as_ptr(),
            scratch.a.as_mut_ptr(),
            c,
            d_a,
        );
        kernel::add_avx(
            scratch.a.as_ptr(),
            attn.a0.as_ptr(),
            scratch.a.as_mut_ptr(),
            c,
        );
        kernel::sigmoid_avx(scratch.a.as_ptr(), scratch.a.as_mut_ptr(), c);

        // g = sigmoid(xg @ g1.T) @ g2.T
        kernel::gemv_avx(
            attn.g1.as_ptr(),
            scratch.xg.as_ptr(),
            scratch.w_lora_tmp.as_mut_ptr(),
            d_g,
            c,
        );
        kernel::sigmoid_avx(
            scratch.w_lora_tmp.as_ptr(),
            scratch.w_lora_tmp.as_mut_ptr(),
            d_g,
        );
        kernel::gemv_avx(
            attn.g2.as_ptr(),
            scratch.w_lora_tmp.as_ptr(),
            scratch.g.as_mut_ptr(),
            c,
            d_g,
        );

        // Value residual (layer > 0)
        if layer_idx == 0 {
            // Copy v to v_first buffer (no allocation)
            state.v_first.copy_from(&scratch.v);
            state.v_first_set = true;
        } else if state.v_first_set {
            if let (Some(v1), Some(v2), Some(v0)) = (&attn.v1, &attn.v2, &attn.v0) {
                let d_v = self.cfg.v_low_rank;
                // nu = sigmoid(xv @ v1.T @ v2.T + v0)
                kernel::gemv_avx(
                    v1.as_ptr(),
                    scratch.xv.as_ptr(),
                    scratch.w_lora_tmp.as_mut_ptr(),
                    d_v,
                    c,
                );
                kernel::gemv_avx(
                    v2.as_ptr(),
                    scratch.w_lora_tmp.as_ptr(),
                    scratch.att_out.as_mut_ptr(), // reuse as temp
                    c,
                    d_v,
                );
                kernel::add_avx(
                    scratch.att_out.as_ptr(),
                    v0.as_ptr(),
                    scratch.att_out.as_mut_ptr(),
                    c,
                );
                kernel::sigmoid_avx(scratch.att_out.as_ptr(), scratch.att_out.as_mut_ptr(), c);
                // v = v + (v_first - v) * nu
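                // (equivalently v = (1 - nu) * v + nu * v_first: a per-channel
                // lerp toward the first layer's value)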
                for i in 0..c {
                    let nu = scratch.att_out[i];
                    scratch.v[i] += (state.v_first[i] - scratch.v[i]) * nu;
                }
            }
        }

        // kk = k * k_k, then L2 normalize per head
        kernel::mul_avx(
            scratch.k.as_ptr(),
            attn.k_k.as_ptr(),
            scratch.kk.as_mut_ptr(),
            c,
        );
        // Normalize per head
        for head in 0..h {
            let offset = head * n;
            kernel::l2_normalize_avx(
                scratch.kk.as_ptr().add(offset),
                scratch.kk.as_mut_ptr().add(offset),
                n,
                1e-12,
            );
        }

        // k = k * (1 + (a - 1) * k_a)
        for i in 0..c {
            let scale = 1.0 + (scratch.a[i] - 1.0) * attn.k_a[i];
            scratch.k[i] *= scale;
        }

        // WKV state update: S = S*w.T - S@kk*(kk*a).T + v*k.T; y = S@r
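        // Read per head: with N-element slices w, kk, a, k, v, r and N x N state S,
        //   S <- S * diag(w) - (S @ kk)(kk * a)^T + v k^T,   then y <- S @ r
        // (an illustrative reading; the exact index layout is defined by
        // kernel::rwkv7_wkv_update_avx)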
        kernel::rwkv7_wkv_update_avx(
            layer_state.att_state.as_mut_ptr(),
            scratch.w_decay.as_ptr(),
            scratch.k.as_ptr(),
            scratch.v.as_ptr(),
            scratch.kk.as_ptr(),
            scratch.a.as_ptr(),
            scratch.r.as_ptr(),
            scratch.y.as_mut_ptr(),
            h,
            n,
        );

        // Group norm
        kernel::group_norm_avx(
            scratch.y.as_ptr(),
            attn.g_norm_w.as_ptr(),
            attn.g_norm_b.as_ptr(),
            scratch.y.as_mut_ptr(),
            h,
            n,
            self.cfg.group_norm_eps,
        );

        // Add head-qk term: y += ((r * k * r_k).sum_per_head) * v
        for head in 0..h {
            let offset = head * n;
            let mut alpha = 0.0f32;
            for j in 0..n {
                alpha += scratch.r[offset + j] * scratch.k[offset + j] * attn.r_k[head * n + j];
            }
            for j in 0..n {
                scratch.y[offset + j] += alpha * scratch.v[offset + j];
            }
        }

        // Apply gate: y = y * g
        kernel::mul_avx(
            scratch.y.as_ptr(),
            scratch.g.as_ptr(),
            scratch.y.as_mut_ptr(),
            c,
        );

        // Output projection: att_out = o_proj @ y
        kernel::gemv_avx(
            attn.o_proj.as_ptr(),
            scratch.y.as_ptr(),
            scratch.att_out.as_mut_ptr(),
            c,
            c,
        );
    }

    #[inline(always)]
    unsafe fn ffn_forward_impl(
        &self,
        scratch: &mut ScratchBuffers,
        layer_idx: usize,
        layer_state: &mut LayerState,
    ) {
        let ffn = &self.blocks[layer_idx].ffn;
        let c = self.cfg.hidden_size;
        let i = self.cfg.intermediate_size;

        // Token shift: xk = x_normed + x_k * (prev - x_normed)
        kernel::token_shift_avx(
            scratch.x_normed.as_ptr(),
            layer_state.ffn_x_prev.as_ptr(),
            ffn.x_k.as_ptr(),
            scratch.xk.as_mut_ptr(),
            c,
        );

        // Update prev state
        kernel::copy(
            scratch.x_normed.as_ptr(),
            layer_state.ffn_x_prev.as_mut_ptr(),
            c,
        );

        // k = relu(xk @ key_w.T)^2
        kernel::gemv_avx(
            ffn.key_w.as_ptr(),
            scratch.xk.as_ptr(),
            scratch.ffn_k.as_mut_ptr(),
            i,
            c,
        );
        kernel::relu_squared_avx(scratch.ffn_k.as_ptr(), scratch.ffn_k.as_mut_ptr(), i);

        // ffn_out = k @ value_w.T
        kernel::gemv_avx(
            ffn.value_w.as_ptr(),
            scratch.ffn_k.as_ptr(),
            scratch.ffn_out.as_mut_ptr(),
            c,
            i,
        );
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_default() {
        let cfg = Config::default();
        assert_eq!(cfg.vocab_size, 256);
        assert_eq!(cfg.hidden_size, 256);
        assert_eq!(cfg.num_layers, 12);
        assert_eq!(cfg.num_heads, 4);
        assert_eq!(cfg.head_dim, 64);
    }
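
    // Illustrative extra checks (sketches that exercise only public surface
    // already used in this file).
    #[test]
    fn test_config_head_consistency() {
        // The per-head state indexing above assumes this invariant.
        let cfg = Config::default();
        assert_eq!(cfg.hidden_size, cfg.num_heads * cfg.head_dim);
    }

    #[test]
    fn test_state_reset_clears_everything() {
        let cfg = Config::default();
        let mut state = State::new(&cfg);
        state.v_first_set = true;
        state.layers[0].att_state.as_mut_slice()[0] = 1.0;
        state.reset();
        assert!(!state.v_first_set);
        assert!(state.layers[0]
            .att_state
            .as_slice()
            .iter()
            .all(|&x| x == 0.0));
    }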
}