// rwkvzip/rwkv7/training/wkv_cuda.rs

//! FFI bindings for CUDA WKV7 kernels.
//!
//! This module provides Rust bindings to custom CUDA kernels for fast WKV computation:
//! - Inference kernel: a single forward pass, with no gradient support.
//! - Training kernels: forward + backward, with state checkpointing for gradient computation.
//!
//! The training kernels process the entire sequence in a single launch per direction
//! (forward or backward), avoiding the overhead of roughly T * 10 kernel launches per
//! attention layer.
9
10use tch::{Device, Kind, Tensor};
11
// Sizes the CUDA kernels are compiled with. The Rust side drives the kernels
// through raw pointers, so these must stay in lock-step with the .cu build.

/// Per-head feature dimension (`_N_` in the CUDA sources).
pub const HEAD_DIM: i64 = 64;

/// Timesteps between saved state checkpoints (`_CHUNK_LEN_` in the CUDA sources).
pub const CHUNK_LEN: i64 = 32;
15
// Raw FFI entry points into the CUDA WKV7 kernels, declared only when the
// crate is built with the `cuda_wkv` feature. The kernels receive bare f32
// pointers plus dimensions, so every caller must guarantee that each buffer
// is large enough for the dimensions passed and lives where the kernel can
// read it (presumably CUDA device memory — confirm against the .cu sources).
#[cfg(feature = "cuda_wkv")]
extern "C" {
    /// Inference forward pass: no checkpointing, no gradient support.
    ///
    /// Reads `state_in` and writes the outputs to `y` and the final state to
    /// `state_out`. `batch_heads` is the flattened batch * heads count.
    fn wkv7_cuda_forward(
        batch_heads: i32,
        seq_len: i32,
        head_dim: i32,
        r: *const f32,
        w: *const f32,
        k: *const f32,
        v: *const f32,
        a: *const f32,
        kk: *const f32,
        state_in: *const f32,
        y: *mut f32,
        state_out: *mut f32,
    );

    /// Training forward pass.
    ///
    /// Besides `y`, writes checkpointed states to `s_out` and the saved `sa`
    /// values to `sa_out`; both are meant to be fed to the backward kernel.
    fn wkv7_cuda_forward_train(
        batch: i32,
        seq_len: i32,
        heads: i32,
        head_dim: i32,
        r: *const f32,
        w: *const f32,
        k: *const f32,
        v: *const f32,
        a: *const f32,
        b_vec: *const f32, // the `b` vector of the RWKV7 recurrence
        y: *mut f32,
        s_out: *mut f32,
        sa_out: *mut f32,
    );

    /// Training backward pass.
    ///
    /// Consumes the checkpoints (`s_in`, `sa_in`) produced by the training
    /// forward pass together with the upstream gradient `dy`, and writes the
    /// gradients `dr`, `dw`, `dk`, `dv`, `da`, `db`.
    fn wkv7_cuda_backward_train(
        batch: i32,
        seq_len: i32,
        heads: i32,
        head_dim: i32,
        r: *const f32,
        w: *const f32,
        k: *const f32,
        v: *const f32,
        a: *const f32,
        b_vec: *const f32,
        dy: *const f32,
        s_in: *const f32,
        sa_in: *const f32,
        dr: *mut f32,
        dw: *mut f32,
        dk: *mut f32,
        dv: *mut f32,
        da: *mut f32,
        db: *mut f32,
    );
}
75
/// Report whether this build includes the custom CUDA WKV kernels.
///
/// The answer is fixed at compile time by the `cuda_wkv` Cargo feature, the
/// same feature that gates the kernel FFI declarations in this module.
pub fn cuda_wkv_available() -> bool {
    const COMPILED_WITH_KERNELS: bool = cfg!(feature = "cuda_wkv");
    COMPILED_WITH_KERNELS
}
80
/// Compute WKV with the custom CUDA training-forward kernel.
///
/// Launches `wkv7_cuda_forward_train` once for the whole sequence. Besides
/// `y`, the kernel fills a state-checkpoint buffer (one `N x N` state per
/// `CHUNK_LEN` timesteps, per batch and head) and a saved-`sa` buffer, which
/// are the `s_in` / `sa_in` inputs of `wkv7_cuda_backward_train`.
///
/// # Gradient caveat
///
/// This wrapper does NOT hook into autograd. `y` is a freshly allocated
/// tensor that the kernel writes through a raw pointer, so it carries no
/// autograd history linking it to `r, w, k, v, a, b`. The checkpoint buffers
/// are dropped when this function returns, and the backward kernel is never
/// launched from here. Backpropagating through `y` therefore produces no
/// gradients for the inputs.
///
/// # Formulation
///
/// This differs slightly from the one in wkv_fused.rs and follows the
/// official RWKV code:
///   sa = sum_j(a[j] * state[j])
///   state[j] = state[j] * w[j] + k[j] * v + sa * b[j]
///   y = sum_j(state[j] * r[j])
///
/// # Arguments
///
/// - `r, w, k, v, a, b`: `(B, T, H, N)` tensors with `N == HEAD_DIM`, all on
///   the same CUDA device (`w` is raw, not exp'd). Each is made contiguous and
///   cast to `f32` before the launch.
///
/// # Returns
///
/// `y`: `(B, T, H, N)`, `f32`, on the inputs' device.
///
/// # Panics
///
/// Panics in any of these cases:
/// - `r` is not 4-D.
/// - Any other input's shape or device differs from `r`'s.
/// - `N != HEAD_DIM`.
/// - The inputs are not on a CUDA device.
/// - A dimension does not fit in `i32`.
#[cfg(feature = "cuda_wkv")]
pub fn wkv_cuda_train(
    r: &Tensor,
    w: &Tensor,
    k: &Tensor,
    v: &Tensor,
    a: &Tensor,
    b: &Tensor,
) -> Tensor {
    let shape = r.size();
    assert_eq!(
        shape.len(),
        4,
        "wkv_cuda_train expects (B, T, H, N) inputs, got shape {:?}",
        shape
    );
    let (batch, seq_len, heads, head_dim) = (shape[0], shape[1], shape[2], shape[3]);

    assert_eq!(
        head_dim, HEAD_DIM,
        "Head dimension must be {} for CUDA kernel",
        HEAD_DIM
    );

    // The kernel only sees raw f32 pointers plus (B, T, H, N). A mismatched
    // shape would make it read out of bounds, and a CPU tensor would hand it a
    // host pointer. Reject both before launching.
    let device = r.device();
    assert!(
        matches!(device, Device::Cuda(_)),
        "wkv_cuda_train requires CUDA tensors, got device {:?}",
        device
    );
    for &(name, x) in [("w", w), ("k", k), ("v", v), ("a", a), ("b", b)].iter() {
        assert_eq!(
            x.size(),
            shape,
            "wkv_cuda_train: `{}` shape must match `r`",
            name
        );
        assert_eq!(
            x.device(),
            device,
            "wkv_cuda_train: `{}` must be on the same device as `r`",
            name
        );
    }

    // The kernel takes i32 dimensions; refuse casts that would silently truncate.
    let to_i32 = |value: i64, what: &str| -> i32 {
        assert!(
            (0..=i64::from(i32::MAX)).contains(&value),
            "wkv_cuda_train: {} = {} does not fit in i32",
            what,
            value
        );
        value as i32
    };
    let batch_i32 = to_i32(batch, "batch");
    let seq_len_i32 = to_i32(seq_len, "seq_len");
    let heads_i32 = to_i32(heads, "heads");
    let head_dim_i32 = to_i32(head_dim, "head_dim");

    // Output buffer, written by the kernel.
    let y = Tensor::zeros([batch, seq_len, heads, head_dim], (Kind::Float, device));
    if y.numel() == 0 {
        // B, T or H is zero: nothing to compute, so don't launch on empty buffers.
        return y;
    }

    // Contiguous fp32 buffers, since the kernel indexes flat f32 memory. These
    // bindings keep the (possibly newly allocated) tensors alive until the
    // function returns, i.e. past the launch and the synchronize below.
    let r = r.contiguous().to_kind(Kind::Float);
    let w = w.contiguous().to_kind(Kind::Float);
    let k = k.contiguous().to_kind(Kind::Float);
    let v = v.contiguous().to_kind(Kind::Float);
    let a = a.contiguous().to_kind(Kind::Float);
    let b = b.contiguous().to_kind(Kind::Float);

    // Checkpoint storage: one N x N state per CHUNK_LEN timesteps, plus sa per step.
    let num_chunks = (seq_len + CHUNK_LEN - 1) / CHUNK_LEN;
    let s_checkpoints = Tensor::zeros(
        [batch, heads, num_chunks, head_dim, head_dim],
        (Kind::Float, device),
    );
    let sa_saved = Tensor::zeros([batch, seq_len, heads, head_dim], (Kind::Float, device));

    // SAFETY: every input is a contiguous f32 tensor of shape (B, T, H, N) on
    // the same CUDA device, as validated above. The output buffers are
    // allocated on that device with the shapes this wrapper has always
    // provided to the kernel. All tensors outlive the call and the device
    // synchronize that follows it.
    unsafe {
        wkv7_cuda_forward_train(
            batch_i32,
            seq_len_i32,
            heads_i32,
            head_dim_i32,
            r.data_ptr() as *const f32,
            w.data_ptr() as *const f32,
            k.data_ptr() as *const f32,
            v.data_ptr() as *const f32,
            a.data_ptr() as *const f32,
            b.data_ptr() as *const f32,
            y.data_ptr() as *mut f32,
            s_checkpoints.data_ptr() as *mut f32,
            sa_saved.data_ptr() as *mut f32,
        );
    }

    // Wait for the kernel before the input and checkpoint buffers are dropped,
    // in case the launch is asynchronous.
    if let Device::Cuda(idx) = device {
        tch::Cuda::synchronize(idx as i64);
    }

    y
}
172
173/// WKV computation without custom CUDA kernel (fallback).
174/// Uses the checkpointed approach from wkv_fused.rs
175#[cfg(not(feature = "cuda_wkv"))]
176pub fn wkv_cuda_train(
177    _r: &Tensor,
178    _w: &Tensor,
179    _k: &Tensor,
180    _v: &Tensor,
181    _a: &Tensor,
182    _b: &Tensor,
183) -> Tensor {
184    panic!("CUDA WKV kernel not available. Rebuild with CUDA toolkit or use wkv_fused.");
185}
186
187/// Fallback inference function
188#[cfg(not(feature = "cuda_wkv"))]
189pub fn wkv_cuda(
190    _r: &Tensor,
191    _k: &Tensor,
192    _v: &Tensor,
193    _w: &Tensor,
194    _a: &Tensor,
195    _kk: &Tensor,
196    _state: &Tensor,
197) -> (Tensor, Tensor) {
198    panic!("CUDA WKV kernel not available.");
199}
200
/// Inference-only WKV forward pass through the custom CUDA kernel (no autograd).
///
/// # Arguments
///
/// - `r, k, v, w, a, kk`: `(B, T, H, N)` tensors with `N == HEAD_DIM`, all on
///   the same CUDA device.
/// - `state`: recurrent state holding `B * H * N * N` elements, laid out as
///   `(B, H, N, N)`.
///
/// # Returns
///
/// `(y, state_out)`, with shapes `(B, T, H, N)` and `(B, H, N, N)`, in the
/// dtype of `r`. The kernel itself always computes in `f32`.
///
/// # Layout
///
/// The kernel is launched with `(B*H, T, N)` dimensions and flat pointers.
/// This wrapper assumes the kernel indexes its buffers as `[B*H][T][N]`, which
/// is what those dimensions imply (confirm against the .cu source). Inputs
/// are therefore permuted to `(B, H, T, N)` before flattening. A plain
/// `reshape` of `(B, T, H, N)` to `(B*H, T, N)` keeps the original memory
/// order and would interleave heads with timesteps whenever `T > 1` and
/// `H > 1`. `y` is permuted back the same way.
///
/// # Panics
///
/// Panics in any of these cases:
/// - `r` is not 4-D.
/// - Any other input's shape or device differs from `r`'s.
/// - `N != HEAD_DIM`.
/// - The tensors are not on a CUDA device.
/// - `state` has the wrong element count.
/// - A dimension does not fit in `i32`.
#[cfg(feature = "cuda_wkv")]
pub fn wkv_cuda(
    r: &Tensor,
    k: &Tensor,
    v: &Tensor,
    w: &Tensor,
    a: &Tensor,
    kk: &Tensor,
    state: &Tensor,
) -> (Tensor, Tensor) {
    let shape = r.size();
    assert_eq!(
        shape.len(),
        4,
        "wkv_cuda expects (B, T, H, N) inputs, got shape {:?}",
        shape
    );
    let (b, t, h, n) = (shape[0], shape[1], shape[2], shape[3]);
    let bh = b * h;

    assert_eq!(
        n, HEAD_DIM,
        "Head dimension must be {} for CUDA kernel",
        HEAD_DIM
    );

    // The kernel only sees raw f32 pointers, so shape or device mistakes would
    // turn into out-of-bounds or host-pointer accesses on the GPU.
    let device = r.device();
    assert!(
        matches!(device, Device::Cuda(_)),
        "wkv_cuda requires CUDA tensors, got device {:?}",
        device
    );
    for &(name, x) in [("k", k), ("v", v), ("w", w), ("a", a), ("kk", kk)].iter() {
        assert_eq!(x.size(), shape, "wkv_cuda: `{}` shape must match `r`", name);
        assert_eq!(
            x.device(),
            device,
            "wkv_cuda: `{}` must be on the same device as `r`",
            name
        );
    }
    assert_eq!(
        state.numel(),
        (bh * n * n) as usize,
        "wkv_cuda: `state` must hold B * H * N * N elements"
    );
    assert_eq!(
        state.device(),
        device,
        "wkv_cuda: `state` must be on the same device as `r`"
    );

    // The kernel takes i32 dimensions; refuse casts that would silently truncate.
    let to_i32 = |value: i64, what: &str| -> i32 {
        assert!(
            (0..=i64::from(i32::MAX)).contains(&value),
            "wkv_cuda: {} = {} does not fit in i32",
            what,
            value
        );
        value as i32
    };
    let bh_i32 = to_i32(bh, "batch * heads");
    let t_i32 = to_i32(t, "seq_len");
    let n_i32 = to_i32(n, "head_dim");

    let out_kind = r.kind();

    // (B, T, H, N) -> (B, H, T, N) -> (B*H, T, N), as contiguous f32.
    let to_kernel_layout = |x: &Tensor| {
        x.to_kind(Kind::Float)
            .permute([0, 2, 1, 3])
            .contiguous()
            .reshape([bh, t, n])
    };
    let r = to_kernel_layout(r);
    let k = to_kernel_layout(k);
    let v = to_kernel_layout(v);
    let w = to_kernel_layout(w);
    let a = to_kernel_layout(a);
    let kk = to_kernel_layout(kk);
    let state = state.to_kind(Kind::Float).reshape([bh, n, n]).contiguous();

    let y = Tensor::zeros([bh, t, n], (Kind::Float, device));
    let state_out = Tensor::zeros([bh, n, n], (Kind::Float, device));

    // SAFETY: every buffer is a contiguous f32 tensor on the same CUDA device,
    // with the (B*H, T, N) / (B*H, N, N) shapes matching the dimensions passed.
    // All tensors stay alive until after the device synchronize below.
    unsafe {
        wkv7_cuda_forward(
            bh_i32,
            t_i32,
            n_i32,
            r.data_ptr() as *const f32,
            w.data_ptr() as *const f32,
            k.data_ptr() as *const f32,
            v.data_ptr() as *const f32,
            a.data_ptr() as *const f32,
            kk.data_ptr() as *const f32,
            state.data_ptr() as *const f32,
            y.data_ptr() as *mut f32,
            state_out.data_ptr() as *mut f32,
        );
    }

    if let Device::Cuda(idx) = device {
        tch::Cuda::synchronize(idx as i64);
    }

    // (B*H, T, N) -> (B, H, T, N) -> (B, T, H, N), back in the caller's dtype.
    let y = y
        .reshape([b, h, t, n])
        .permute([0, 2, 1, 3])
        .contiguous()
        .to_kind(out_kind);
    let state_out = state_out.reshape([b, h, n, n]).to_kind(out_kind);
    (y, state_out)
}
250
#[cfg(test)]
mod tests {
    use super::*;

    /// The availability query must agree with how the crate was compiled.
    #[test]
    fn test_cuda_wkv_available() {
        let available = cuda_wkv_available();
        println!("CUDA WKV available: {}", available);
        assert_eq!(available, cfg!(feature = "cuda_wkv"));
    }

    /// Pin the sizes the CUDA kernels are compiled against.
    #[test]
    fn test_constants() {
        assert_eq!(HEAD_DIM, 64);
        assert_eq!(CHUNK_LEN, 32);
    }
}