// rwkvzip/rwkv7/tensor.rs

1//! Simple aligned tensor types for SIMD operations.
2//!
3//! These are minimal, no-frills tensor implementations designed for:
4//! - Aligned memory for AVX2 (32-byte alignment)
5//! - Direct access to underlying data
6//! - Zero-copy views for weights
7
8use std::alloc::{alloc_zeroed, dealloc, Layout};
9use std::ops::{Index, IndexMut};
10use std::ptr::NonNull;
11
/// 32-byte alignment for AVX2
const ALIGNMENT: usize = 32;

/// Owned 1D tensor with aligned memory.
#[repr(C)]
pub struct Tensor1D {
    data: NonNull<f32>,
    len: usize,
}

impl Tensor1D {
    /// Allocation layout for `len` f32 elements, 32-byte aligned.
    fn layout(len: usize) -> Layout {
        Layout::from_size_align(len * std::mem::size_of::<f32>(), ALIGNMENT)
            .expect("Invalid layout")
    }

    /// Create a new zero-initialized tensor.
    ///
    /// A zero-length tensor does not allocate: calling the global allocator
    /// with a zero-size layout is undefined behavior, so a dangling
    /// (never-dereferenced) aligned pointer is used instead.
    pub fn zeros(len: usize) -> Self {
        if len == 0 {
            return Self {
                data: NonNull::dangling(),
                len: 0,
            };
        }

        let layout = Self::layout(len);
        // SAFETY: layout has non-zero size because len > 0.
        let ptr = unsafe { alloc_zeroed(layout) as *mut f32 };
        let data = NonNull::new(ptr).expect("Allocation failed");

        Self { data, len }
    }

    /// Create from an existing `Vec<f32>` (always copies into aligned storage).
    pub fn from_vec(v: Vec<f32>) -> Self {
        let mut t = Self::zeros(v.len());
        t.as_mut_slice().copy_from_slice(&v);
        t
    }

    /// Number of elements.
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }

    /// True when the tensor holds no elements.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Raw const pointer to the first element.
    #[inline]
    pub fn as_ptr(&self) -> *const f32 {
        self.data.as_ptr()
    }

    /// Raw mutable pointer to the first element.
    #[inline]
    pub fn as_mut_ptr(&mut self) -> *mut f32 {
        self.data.as_ptr()
    }

    /// Borrow the contents as a slice.
    #[inline]
    pub fn as_slice(&self) -> &[f32] {
        // SAFETY: data points to `len` initialized f32s, or is a dangling
        // aligned pointer when len == 0 (valid for a zero-length slice).
        unsafe { std::slice::from_raw_parts(self.data.as_ptr(), self.len) }
    }

    /// Borrow the contents as a mutable slice.
    #[inline]
    pub fn as_mut_slice(&mut self) -> &mut [f32] {
        // SAFETY: same as `as_slice`; `&mut self` guarantees exclusivity.
        unsafe { std::slice::from_raw_parts_mut(self.data.as_ptr(), self.len) }
    }

    /// Fill with zeros.
    #[inline]
    pub fn zero(&mut self) {
        // SAFETY: writes exactly the `len` elements this tensor owns
        // (no-op when len == 0).
        unsafe {
            std::ptr::write_bytes(self.data.as_ptr(), 0, self.len);
        }
    }

    /// Copy from another tensor of the same length.
    #[inline]
    pub fn copy_from(&mut self, other: &Tensor1D) {
        // The inner copy_from_slice also panics on length mismatch.
        debug_assert_eq!(self.len, other.len);
        self.as_mut_slice().copy_from_slice(other.as_slice());
    }

    /// Copy from a slice of the same length.
    #[inline]
    pub fn copy_from_slice(&mut self, slice: &[f32]) {
        debug_assert_eq!(self.len, slice.len());
        self.as_mut_slice().copy_from_slice(slice);
    }
}

impl Clone for Tensor1D {
    fn clone(&self) -> Self {
        let mut new = Self::zeros(self.len);
        new.as_mut_slice().copy_from_slice(self.as_slice());
        new
    }
}

impl Drop for Tensor1D {
    fn drop(&mut self) {
        // Zero-length tensors never allocated; deallocating a zero-size
        // layout would be undefined behavior.
        if self.len == 0 {
            return;
        }
        // SAFETY: data was produced by alloc_zeroed with this exact layout.
        unsafe {
            dealloc(self.data.as_ptr() as *mut u8, Self::layout(self.len));
        }
    }
}

// Safety: Tensor1D exclusively owns its heap allocation.
unsafe impl Send for Tensor1D {}
unsafe impl Sync for Tensor1D {}

impl Index<usize> for Tensor1D {
    type Output = f32;

    /// Panics if `i` is out of bounds. Bounds-checked in release builds
    /// too: a safe `Index` impl must never permit out-of-bounds reads.
    #[inline]
    fn index(&self, i: usize) -> &f32 {
        &self.as_slice()[i]
    }
}

impl IndexMut<usize> for Tensor1D {
    /// Panics if `i` is out of bounds.
    #[inline]
    fn index_mut(&mut self, i: usize) -> &mut f32 {
        &mut self.as_mut_slice()[i]
    }
}
131
132/// Owned 2D tensor with aligned memory (row-major).
133#[repr(C)]
134pub struct Tensor2D {
135    data: NonNull<f32>,
136    rows: usize,
137    cols: usize,
138    stride: usize, // stride in elements (rounded up for alignment)
139}
140
141impl Tensor2D {
142    /// Create a new zero-initialized 2D tensor.
143    pub fn zeros(rows: usize, cols: usize) -> Self {
144        // Pad cols to 8 for AVX2 alignment
145        let stride = (cols + 7) & !7;
146        let total = rows * stride;
147
148        let layout = Layout::from_size_align(total * 4, ALIGNMENT).expect("Invalid layout");
149
150        let ptr = unsafe { alloc_zeroed(layout) as *mut f32 };
151        let data = NonNull::new(ptr).expect("Allocation failed");
152
153        Self {
154            data,
155            rows,
156            cols,
157            stride,
158        }
159    }
160
161    /// Create from Vec with shape.
162    pub fn from_vec(v: Vec<f32>, rows: usize, cols: usize) -> Self {
163        assert_eq!(v.len(), rows * cols);
164        let mut t = Self::zeros(rows, cols);
165
166        // Copy row by row to handle stride
167        for r in 0..rows {
168            let src_start = r * cols;
169            let src_end = src_start + cols;
170            t.row_mut(r).copy_from_slice(&v[src_start..src_end]);
171        }
172        t
173    }
174
175    #[inline]
176    pub fn rows(&self) -> usize {
177        self.rows
178    }
179
180    #[inline]
181    pub fn cols(&self) -> usize {
182        self.cols
183    }
184
185    #[inline]
186    pub fn stride(&self) -> usize {
187        self.stride
188    }
189
190    #[inline]
191    pub fn as_ptr(&self) -> *const f32 {
192        self.data.as_ptr()
193    }
194
195    #[inline]
196    pub fn as_mut_ptr(&mut self) -> *mut f32 {
197        self.data.as_ptr()
198    }
199
200    /// Get a row slice.
201    #[inline]
202    pub fn row(&self, r: usize) -> &[f32] {
203        debug_assert!(r < self.rows);
204        unsafe {
205            let ptr = self.data.as_ptr().add(r * self.stride);
206            std::slice::from_raw_parts(ptr, self.cols)
207        }
208    }
209
210    /// Get a mutable row slice.
211    #[inline]
212    pub fn row_mut(&mut self, r: usize) -> &mut [f32] {
213        debug_assert!(r < self.rows);
214        unsafe {
215            let ptr = self.data.as_ptr().add(r * self.stride);
216            std::slice::from_raw_parts_mut(ptr, self.cols)
217        }
218    }
219
220    /// Get raw row pointer (includes stride padding).
221    #[inline]
222    pub fn row_ptr(&self, r: usize) -> *const f32 {
223        debug_assert!(r < self.rows);
224        unsafe { self.data.as_ptr().add(r * self.stride) }
225    }
226
227    /// Get raw mutable row pointer.
228    #[inline]
229    pub fn row_ptr_mut(&mut self, r: usize) -> *mut f32 {
230        debug_assert!(r < self.rows);
231        unsafe { self.data.as_ptr().add(r * self.stride) }
232    }
233
234    /// Fill with zeros.
235    pub fn zero(&mut self) {
236        let total = self.rows * self.stride;
237        unsafe {
238            std::ptr::write_bytes(self.data.as_ptr(), 0, total);
239        }
240    }
241}
242
243impl Clone for Tensor2D {
244    fn clone(&self) -> Self {
245        let total = self.rows * self.stride;
246        let layout = Layout::from_size_align(total * 4, ALIGNMENT).expect("Invalid layout");
247
248        let ptr = unsafe { alloc_zeroed(layout) as *mut f32 };
249        let data = NonNull::new(ptr).expect("Allocation failed");
250
251        unsafe {
252            std::ptr::copy_nonoverlapping(self.data.as_ptr(), ptr, total);
253        }
254
255        Self {
256            data,
257            rows: self.rows,
258            cols: self.cols,
259            stride: self.stride,
260        }
261    }
262}
263
264impl Drop for Tensor2D {
265    fn drop(&mut self) {
266        let total = self.rows * self.stride;
267        let layout = Layout::from_size_align(total * 4, ALIGNMENT).expect("Invalid layout");
268        unsafe {
269            dealloc(self.data.as_ptr() as *mut u8, layout);
270        }
271    }
272}
273
274// Safety: Tensor2D owns its data
275unsafe impl Send for Tensor2D {}
276unsafe impl Sync for Tensor2D {}
277
/// Borrowed, copyable view over external f32 data (used for weights).
#[derive(Clone, Copy)]
pub struct TensorView1D<'a> {
    data: &'a [f32],
}

impl<'a> TensorView1D<'a> {
    /// Wrap an existing slice without copying.
    #[inline]
    pub fn new(data: &'a [f32]) -> Self {
        Self { data }
    }

    /// Number of elements in the view.
    #[inline]
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// True when the view contains no elements.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Raw const pointer to the first element.
    #[inline]
    pub fn as_ptr(&self) -> *const f32 {
        self.data.as_ptr()
    }

    /// The underlying slice.
    #[inline]
    pub fn as_slice(&self) -> &[f32] {
        self.data
    }
}

impl Index<usize> for TensorView1D<'_> {
    type Output = f32;

    /// Panics on out-of-bounds, exactly like slice indexing.
    #[inline]
    fn index(&self, idx: usize) -> &f32 {
        &self.data[idx]
    }
}
319
/// View into external f32 data (for weights), row-major, densely packed
/// (no stride padding, unlike `Tensor2D`).
#[derive(Clone, Copy)]
pub struct TensorView2D<'a> {
    data: &'a [f32],
    rows: usize,
    cols: usize,
}

impl<'a> TensorView2D<'a> {
    /// Wrap a densely packed row-major slice of shape `rows x cols`.
    ///
    /// Panics if `data.len() != rows * cols` — checked in release builds
    /// too, because a wrong shape here would make every later `row()`
    /// call read out of bounds.
    #[inline]
    pub fn new(data: &'a [f32], rows: usize, cols: usize) -> Self {
        assert_eq!(data.len(), rows * cols);
        Self { data, rows, cols }
    }

    /// Number of rows.
    #[inline]
    pub fn rows(&self) -> usize {
        self.rows
    }

    /// Number of columns.
    #[inline]
    pub fn cols(&self) -> usize {
        self.cols
    }

    /// Raw const pointer to the first element.
    #[inline]
    pub fn as_ptr(&self) -> *const f32 {
        self.data.as_ptr()
    }

    /// Get a row slice. Panics if `r` is out of bounds (the slice index
    /// below enforces the bound even in release builds).
    #[inline]
    pub fn row(&self, r: usize) -> &[f32] {
        debug_assert!(r < self.rows);
        let start = r * self.cols;
        &self.data[start..start + self.cols]
    }

    /// Get raw pointer to the start of row `r`. Panics if `r` is out of
    /// bounds: offsetting a pointer past the allocation with `ptr::add`
    /// is undefined behavior even without a dereference, so the offset is
    /// derived from a bounds-checked slice instead.
    #[inline]
    pub fn row_ptr(&self, r: usize) -> *const f32 {
        self.row(r).as_ptr()
    }

    /// Transpose view (returns new TensorView with swapped dims).
    /// Note: This is a logical transpose - data is still row-major of original.
    /// Use only for matmuls that handle transposed right operand.
    pub fn t(&self) -> TransposedView2D<'a> {
        TransposedView2D {
            data: self.data,
            rows: self.cols, // swapped
            cols: self.rows, // swapped
            orig_cols: self.cols,
        }
    }
}

/// Transposed view (for efficient transpose-multiply).
#[derive(Clone, Copy)]
pub struct TransposedView2D<'a> {
    data: &'a [f32],
    rows: usize,
    cols: usize,
    orig_cols: usize, // column count of the original (pre-transpose) layout
}

impl<'a> TransposedView2D<'a> {
    /// Number of rows in the transposed view (= original cols).
    #[inline]
    pub fn rows(&self) -> usize {
        self.rows
    }

    /// Number of columns in the transposed view (= original rows).
    #[inline]
    pub fn cols(&self) -> usize {
        self.cols
    }

    /// Get element at (r, c) in the transposed view.
    /// Panics if the mapped index falls outside the underlying data.
    #[inline]
    pub fn get(&self, r: usize, c: usize) -> f32 {
        // In the transposed view, (r, c) maps to original (c, r).
        debug_assert!(r < self.rows && c < self.cols);
        self.data[c * self.orig_cols + r]
    }

    /// Get original row `r` (which is a column in the transposed view).
    /// Panics if `r` is out of range (the slice index enforces it).
    #[inline]
    pub fn orig_row(&self, r: usize) -> &[f32] {
        let start = r * self.orig_cols;
        &self.data[start..start + self.orig_cols]
    }
}