use tch::{Device, Kind, Tensor};

/// Head dimension the CUDA WKV kernels are compiled for.
pub const HEAD_DIM: i64 = 64;
/// Sequence-chunk length used by the training kernel when saving state checkpoints.
pub const CHUNK_LEN: i64 = 32;

#[cfg(feature = "cuda_wkv")]
extern "C" {
    // Inference forward pass; batch and head dimensions are fused into `bh`.
    fn wkv7_cuda_forward(
        bh: i32,
        t: i32,
        n: i32,
        r: *const f32,
        w: *const f32,
        k: *const f32,
        v: *const f32,
        a: *const f32,
        kk: *const f32,
        state_in: *const f32,
        y: *mut f32,
        state_out: *mut f32,
    );

    // Training forward pass; besides `y`, it writes per-chunk state checkpoints
    // (`s_out`) and per-step `sa` terms (`sa_out`) consumed by the backward kernel.
    fn wkv7_cuda_forward_train(
        b: i32,
        t: i32,
        h: i32,
        n: i32,
        r: *const f32,
        w: *const f32,
        k: *const f32,
        v: *const f32,
        a: *const f32,
        b_vec: *const f32,
        y: *mut f32,
        s_out: *mut f32,
        sa_out: *mut f32,
    );

    // Training backward pass; consumes the saved checkpoints and `dy`, and
    // produces gradients for all six inputs.
    fn wkv7_cuda_backward_train(
        b: i32,
        t: i32,
        h: i32,
        n: i32,
        r: *const f32,
        w: *const f32,
        k: *const f32,
        v: *const f32,
        a: *const f32,
        b_vec: *const f32,
        dy: *const f32,
        s_in: *const f32,
        sa_in: *const f32,
        dr: *mut f32,
        dw: *mut f32,
        dk: *mut f32,
        dv: *mut f32,
        da: *mut f32,
        db: *mut f32,
    );
}

/// Reports whether this build includes the CUDA WKV kernels.
/// This is a compile-time check only; it does not probe for a CUDA device at runtime.
pub fn cuda_wkv_available() -> bool {
    cfg!(feature = "cuda_wkv")
}

/// Training forward pass of the WKV7 kernel.
/// All inputs are expected as `[batch, seq_len, heads, head_dim]` tensors on a
/// CUDA device; the output `y` has the same shape.
#[cfg(feature = "cuda_wkv")]
pub fn wkv_cuda_train(
    r: &Tensor,
    w: &Tensor,
    k: &Tensor,
    v: &Tensor,
    a: &Tensor,
    b: &Tensor,
) -> Tensor {
    let sizes = r.size();
    let (batch, seq_len, heads, head_dim) = (sizes[0], sizes[1], sizes[2], sizes[3]);

    assert_eq!(
        head_dim, HEAD_DIM,
        "Head dimension must be {} for CUDA kernel",
        HEAD_DIM
    );

    let r = r.contiguous().to_kind(Kind::Float);
    let w = w.contiguous().to_kind(Kind::Float);
    let k = k.contiguous().to_kind(Kind::Float);
    let v = v.contiguous().to_kind(Kind::Float);
    let a = a.contiguous().to_kind(Kind::Float);
    let b = b.contiguous().to_kind(Kind::Float);

    let y = Tensor::zeros([batch, seq_len, heads, head_dim], (Kind::Float, r.device()));

    // Buffers the kernel fills with per-chunk state checkpoints and per-step
    // `sa` terms; only `y` is returned from this wrapper.
    let num_chunks = (seq_len + CHUNK_LEN - 1) / CHUNK_LEN;
    let s_checkpoints = Tensor::zeros(
        [batch, heads, num_chunks, head_dim, head_dim],
        (Kind::Float, r.device()),
    );
    let sa_saved = Tensor::zeros([batch, seq_len, heads, head_dim], (Kind::Float, r.device()));

    let r_ptr = r.data_ptr() as *const f32;
    let w_ptr = w.data_ptr() as *const f32;
    let k_ptr = k.data_ptr() as *const f32;
    let v_ptr = v.data_ptr() as *const f32;
    let a_ptr = a.data_ptr() as *const f32;
    let b_ptr = b.data_ptr() as *const f32;
    let y_ptr = y.data_ptr() as *mut f32;
    let s_ptr = s_checkpoints.data_ptr() as *mut f32;
    let sa_ptr = sa_saved.data_ptr() as *mut f32;

    unsafe {
        wkv7_cuda_forward_train(
            batch as i32,
            seq_len as i32,
            heads as i32,
            head_dim as i32,
            r_ptr,
            w_ptr,
            k_ptr,
            v_ptr,
            a_ptr,
            b_ptr,
            y_ptr,
            s_ptr,
            sa_ptr,
        );
    }

    // Block until the kernel has finished writing through the raw pointers.
    if let Device::Cuda(idx) = r.device() {
        tch::Cuda::synchronize(idx as i64);
    }

    y
}

#[cfg(not(feature = "cuda_wkv"))]
pub fn wkv_cuda_train(
    _r: &Tensor,
    _w: &Tensor,
    _k: &Tensor,
    _v: &Tensor,
    _a: &Tensor,
    _b: &Tensor,
) -> Tensor {
    panic!("CUDA WKV kernel not available. Rebuild with the CUDA toolkit or use wkv_fused.");
}

#[cfg(not(feature = "cuda_wkv"))]
pub fn wkv_cuda(
    _r: &Tensor,
    _k: &Tensor,
    _v: &Tensor,
    _w: &Tensor,
    _a: &Tensor,
    _kk: &Tensor,
    _state: &Tensor,
) -> (Tensor, Tensor) {
    panic!("CUDA WKV kernel not available.");
}

/// Inference forward pass of the WKV7 kernel with an explicit recurrent state.
/// Inputs are `[b, t, h, n]` tensors plus a `[b, h, n, n]` state;
/// returns `(y, new_state)` with those same layouts.
#[cfg(feature = "cuda_wkv")]
pub fn wkv_cuda(
    r: &Tensor,
    k: &Tensor,
    v: &Tensor,
    w: &Tensor,
    a: &Tensor,
    kk: &Tensor,
    state: &Tensor,
) -> (Tensor, Tensor) {
    let sizes = r.size();
    let (b, t, h, n) = (sizes[0], sizes[1], sizes[2], sizes[3]);
    // The kernel works on a fused batch*heads dimension.
    let bh = b * h;

    let r = r.reshape([bh, t, n]).contiguous();
    let k = k.reshape([bh, t, n]).contiguous();
    let v = v.reshape([bh, t, n]).contiguous();
    let w = w.reshape([bh, t, n]).contiguous();
    let a = a.reshape([bh, t, n]).contiguous();
    let kk = kk.reshape([bh, t, n]).contiguous();
    let state = state.reshape([bh, n, n]).contiguous();

    let y = Tensor::zeros([bh, t, n], (r.kind(), r.device()));
    let state_out = Tensor::zeros([bh, n, n], (r.kind(), r.device()));

    // The kernel consumes raw f32 pointers, so callers are expected to pass
    // f32 tensors here (the training wrapper above converts explicitly).
    unsafe {
        wkv7_cuda_forward(
            bh as i32,
            t as i32,
            n as i32,
            r.data_ptr() as *const f32,
            w.data_ptr() as *const f32,
            k.data_ptr() as *const f32,
            v.data_ptr() as *const f32,
            a.data_ptr() as *const f32,
            kk.data_ptr() as *const f32,
            state.data_ptr() as *const f32,
            y.data_ptr() as *mut f32,
            state_out.data_ptr() as *mut f32,
        );
    }

    if let Device::Cuda(idx) = r.device() {
        tch::Cuda::synchronize(idx as i64);
    }

    (y.reshape([b, t, h, n]), state_out.reshape([b, h, n, n]))
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cuda_wkv_available() {
        let available = cuda_wkv_available();
        println!("CUDA WKV available: {}", available);
    }
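
    // A minimal smoke-test sketch for `wkv_cuda` (assumption: a CUDA device is
    // present at runtime and these random inputs are acceptable placeholder
    // values). It only checks output shapes, not numerical correctness, and
    // skips when no GPU is available.
    #[test]
    #[cfg(feature = "cuda_wkv")]
    fn test_wkv_cuda_output_shapes() {
        if !tch::Cuda::is_available() {
            return;
        }
        let device = Device::Cuda(0);
        let opts = (Kind::Float, device);
        let (b, t, h, n) = (1, CHUNK_LEN, 2, HEAD_DIM);
        // Placeholder f32 inputs, matching the pointer casts in `wkv_cuda`.
        let r = Tensor::randn([b, t, h, n], opts);
        let k = Tensor::randn([b, t, h, n], opts);
        let v = Tensor::randn([b, t, h, n], opts);
        let w = Tensor::randn([b, t, h, n], opts);
        let a = Tensor::randn([b, t, h, n], opts);
        let kk = Tensor::randn([b, t, h, n], opts);
        let state = Tensor::zeros([b, h, n, n], opts);
        let (y, state_out) = wkv_cuda(&r, &k, &v, &w, &a, &kk, &state);
        assert_eq!(y.size(), vec![b, t, h, n]);
        assert_eq!(state_out.size(), vec![b, h, n, n]);
    }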

    #[test]
    fn test_constants() {
        assert_eq!(HEAD_DIM, 64);
        assert_eq!(CHUNK_LEN, 32);
    }
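
    // A minimal smoke-test sketch for `wkv_cuda_train` under the same
    // assumptions as above (CUDA device present, placeholder f32 inputs,
    // `seq_len` kept at one CHUNK_LEN); only the output shape is checked.
    #[test]
    #[cfg(feature = "cuda_wkv")]
    fn test_wkv_cuda_train_output_shape() {
        if !tch::Cuda::is_available() {
            return;
        }
        let device = Device::Cuda(0);
        let opts = (Kind::Float, device);
        let shape = [1, CHUNK_LEN, 2, HEAD_DIM];
        let r = Tensor::randn(shape, opts);
        let w = Tensor::randn(shape, opts);
        let k = Tensor::randn(shape, opts);
        let v = Tensor::randn(shape, opts);
        let a = Tensor::randn(shape, opts);
        let b = Tensor::randn(shape, opts);
        let y = wkv_cuda_train(&r, &w, &k, &v, &a, &b);
        assert_eq!(y.size(), shape.to_vec());
    }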
}