// infotheory/backends/rwkvzip/rwkv7/tensor.rs

//! Simple aligned tensor types for SIMD operations.
//!
//! These are minimal, no-frills tensor implementations designed for:
//! - 32-byte aligned memory for portable SIMD kernels
//! - Direct access to underlying data
//! - Zero-copy views for weights
8use std::alloc::{Layout, alloc_zeroed, dealloc};
9use std::mem::size_of;
10use std::ops::{Index, IndexMut};
11use std::ptr::NonNull;
12
/// 32-byte alignment for SIMD-friendly access.
const ALIGNMENT: usize = 32;

/// Well-aligned, non-null sentinel pointer for zero-length buffers.
///
/// Never dereferenced; it only gives empty tensors an address that satisfies
/// the 32-byte alignment contract without touching the allocator.
#[inline]
fn dangling_aligned_f32() -> NonNull<f32> {
    // ALIGNMENT is a multiple of f32's alignment, so this address is a valid
    // dangling pointer for an empty f32 buffer.
    debug_assert_eq!(ALIGNMENT % std::mem::align_of::<f32>(), 0);
    let sentinel = ALIGNMENT as *mut f32;
    NonNull::new(sentinel).expect("aligned dangling pointer must be non-null")
}

/// 32-byte-aligned layout for `len` f32 elements.
///
/// Panics when the byte count overflows `usize` or the layout is invalid.
#[inline]
fn layout_for_f32_elems(len: usize) -> Layout {
    let byte_count = size_of::<f32>()
        .checked_mul(len)
        .expect("tensor allocation overflow");
    Layout::from_size_align(byte_count, ALIGNMENT).expect("Invalid layout")
}

/// Allocate a zeroed, 32-byte-aligned buffer of `len` f32s.
///
/// A zero-length request returns the dangling sentinel without touching the
/// allocator; `dealloc_f32_buffer` mirrors that convention.
#[inline]
fn alloc_f32_buffer(len: usize) -> NonNull<f32> {
    if len == 0 {
        return dangling_aligned_f32();
    }
    // SAFETY: `len > 0`, so the layout has non-zero size.
    let raw = unsafe { alloc_zeroed(layout_for_f32_elems(len)) };
    NonNull::new(raw.cast::<f32>()).expect("Allocation failed")
}

/// Free a buffer obtained from [`alloc_f32_buffer`].
///
/// # Safety
/// `ptr` must have been returned by `alloc_f32_buffer(len)` with this exact
/// `len`, and must not be freed more than once.
#[inline]
unsafe fn dealloc_f32_buffer(ptr: NonNull<f32>, len: usize) {
    if len == 0 {
        // Zero-length buffers use the sentinel and were never allocated.
        return;
    }
    // SAFETY: caller guarantees `ptr`/`len` match the original allocation.
    unsafe {
        dealloc(ptr.as_ptr().cast::<u8>(), layout_for_f32_elems(len));
    }
}

/// Round `cols` up to the next multiple of 8 f32 lanes (one 32-byte row step).
#[inline]
fn padded_stride(cols: usize) -> usize {
    let bumped = cols.checked_add(7).expect("tensor stride overflow");
    bumped - bumped % 8
}

/// Owned 1D tensor with aligned memory.
#[repr(C)]
pub struct Tensor1D {
    data: NonNull<f32>, // 32-byte-aligned buffer; dangling sentinel when len == 0
    len: usize,         // number of logical f32 elements
}

65impl Tensor1D {
66    /// Create a new zero-initialized tensor.
67    pub fn zeros(len: usize) -> Self {
68        Self {
69            data: alloc_f32_buffer(len),
70            len,
71        }
72    }
73
74    /// Create from an existing `Vec<f32>` (may copy if not aligned).
75    pub fn from_vec(v: Vec<f32>) -> Self {
76        let mut t = Self::zeros(v.len());
77        t.as_mut_slice().copy_from_slice(&v);
78        t
79    }
80
81    #[inline]
82    /// Number of logical elements.
83    pub fn len(&self) -> usize {
84        self.len
85    }
86
87    #[inline]
88    /// Returns `true` when `len() == 0`.
89    pub fn is_empty(&self) -> bool {
90        self.len == 0
91    }
92
93    #[inline]
94    /// Raw pointer to the aligned backing buffer.
95    pub fn as_ptr(&self) -> *const f32 {
96        self.data.as_ptr()
97    }
98
99    #[inline]
100    /// Mutable raw pointer to the aligned backing buffer.
101    pub fn as_mut_ptr(&mut self) -> *mut f32 {
102        self.data.as_ptr()
103    }
104
105    #[inline]
106    /// Immutable slice over logical elements.
107    pub fn as_slice(&self) -> &[f32] {
108        unsafe { std::slice::from_raw_parts(self.data.as_ptr(), self.len) }
109    }
110
111    #[inline]
112    /// Mutable slice over logical elements.
113    pub fn as_mut_slice(&mut self) -> &mut [f32] {
114        unsafe { std::slice::from_raw_parts_mut(self.data.as_ptr(), self.len) }
115    }
116
117    /// Fill with zeros.
118    #[inline]
119    pub fn zero(&mut self) {
120        unsafe {
121            std::ptr::write_bytes(self.data.as_ptr(), 0, self.len);
122        }
123    }
124
125    /// Copy from another tensor.
126    #[inline]
127    pub fn copy_from(&mut self, other: &Tensor1D) {
128        debug_assert_eq!(self.len, other.len);
129        self.as_mut_slice().copy_from_slice(other.as_slice());
130    }
131
132    /// Copy from slice.
133    #[inline]
134    pub fn copy_from_slice(&mut self, slice: &[f32]) {
135        debug_assert_eq!(self.len, slice.len());
136        self.as_mut_slice().copy_from_slice(slice);
137    }
138}
139
140impl Clone for Tensor1D {
141    fn clone(&self) -> Self {
142        let mut new = Self::zeros(self.len);
143        new.as_mut_slice().copy_from_slice(self.as_slice());
144        new
145    }
146}
147
impl Drop for Tensor1D {
    fn drop(&mut self) {
        // SAFETY: `data` was produced by `alloc_f32_buffer(self.len)` and is
        // freed exactly once; the helper is a no-op for the zero-length
        // sentinel, matching the allocation path.
        unsafe {
            dealloc_f32_buffer(self.data, self.len);
        }
    }
}

// SAFETY: Tensor1D exclusively owns its heap buffer (no aliasing, no interior
// mutability), so moving it to, or sharing `&Tensor1D` across, threads is sound.
unsafe impl Send for Tensor1D {}
unsafe impl Sync for Tensor1D {}

impl Index<usize> for Tensor1D {
    type Output = f32;

    // NOTE(review): bounds are checked only via `debug_assert!` — an
    // out-of-range `i` is UB in release builds. Deliberate for SIMD hot
    // paths; callers must guarantee `i < len()`.
    #[inline]
    fn index(&self, i: usize) -> &f32 {
        debug_assert!(i < self.len);
        // SAFETY: requires `i < self.len`; the buffer holds `len` initialized f32s.
        unsafe { &*self.data.as_ptr().add(i) }
    }
}

impl IndexMut<usize> for Tensor1D {
    // NOTE(review): same contract as `Index` — out-of-range `i` is UB in
    // release builds (only a `debug_assert!` guards it).
    #[inline]
    fn index_mut(&mut self, i: usize) -> &mut f32 {
        debug_assert!(i < self.len);
        // SAFETY: requires `i < self.len`; `&mut self` gives exclusive access.
        unsafe { &mut *self.data.as_ptr().add(i) }
    }
}

/// Owned 2D tensor with aligned memory (row-major).
#[repr(C)]
pub struct Tensor2D {
    data: NonNull<f32>, // 32-byte-aligned buffer of `rows * stride` f32s
    rows: usize,        // logical row count
    cols: usize,        // logical column count (excludes padding)
    stride: usize, // stride in elements (rounded up for alignment)
}

187impl Tensor2D {
188    /// Create a new zero-initialized 2D tensor.
189    pub fn zeros(rows: usize, cols: usize) -> Self {
190        // Pad cols to a multiple of 8 f32 lanes.
191        let stride = padded_stride(cols);
192        let total = rows
193            .checked_mul(stride)
194            .expect("tensor allocation overflow");
195
196        Self {
197            data: alloc_f32_buffer(total),
198            rows,
199            cols,
200            stride,
201        }
202    }
203
204    /// Create from Vec with shape.
205    pub fn from_vec(v: Vec<f32>, rows: usize, cols: usize) -> Self {
206        assert_eq!(v.len(), rows * cols);
207        let mut t = Self::zeros(rows, cols);
208
209        // Copy row by row to handle stride
210        for r in 0..rows {
211            let src_start = r * cols;
212            let src_end = src_start + cols;
213            t.row_mut(r).copy_from_slice(&v[src_start..src_end]);
214        }
215        t
216    }
217
218    #[inline]
219    /// Number of matrix rows.
220    pub fn rows(&self) -> usize {
221        self.rows
222    }
223
224    #[inline]
225    /// Number of logical columns (excluding stride padding).
226    pub fn cols(&self) -> usize {
227        self.cols
228    }
229
230    #[inline]
231    /// Row stride in elements (includes alignment padding).
232    pub fn stride(&self) -> usize {
233        self.stride
234    }
235
236    #[inline]
237    /// Raw pointer to matrix storage.
238    pub fn as_ptr(&self) -> *const f32 {
239        self.data.as_ptr()
240    }
241
242    #[inline]
243    /// Mutable raw pointer to matrix storage.
244    pub fn as_mut_ptr(&mut self) -> *mut f32 {
245        self.data.as_ptr()
246    }
247
248    /// Get a row slice.
249    #[inline]
250    pub fn row(&self, r: usize) -> &[f32] {
251        debug_assert!(r < self.rows);
252        unsafe {
253            let ptr = self.data.as_ptr().add(r * self.stride);
254            std::slice::from_raw_parts(ptr, self.cols)
255        }
256    }
257
258    /// Get a mutable row slice.
259    #[inline]
260    pub fn row_mut(&mut self, r: usize) -> &mut [f32] {
261        debug_assert!(r < self.rows);
262        unsafe {
263            let ptr = self.data.as_ptr().add(r * self.stride);
264            std::slice::from_raw_parts_mut(ptr, self.cols)
265        }
266    }
267
268    /// Get raw row pointer (includes stride padding).
269    #[inline]
270    pub fn row_ptr(&self, r: usize) -> *const f32 {
271        debug_assert!(r < self.rows);
272        unsafe { self.data.as_ptr().add(r * self.stride) }
273    }
274
275    /// Get raw mutable row pointer.
276    #[inline]
277    pub fn row_ptr_mut(&mut self, r: usize) -> *mut f32 {
278        debug_assert!(r < self.rows);
279        unsafe { self.data.as_ptr().add(r * self.stride) }
280    }
281
282    /// Fill with zeros.
283    pub fn zero(&mut self) {
284        let total = self
285            .rows
286            .checked_mul(self.stride)
287            .expect("tensor allocation overflow");
288        unsafe {
289            std::ptr::write_bytes(self.data.as_ptr(), 0, total);
290        }
291    }
292}
293
294impl Clone for Tensor2D {
295    fn clone(&self) -> Self {
296        let total = self
297            .rows
298            .checked_mul(self.stride)
299            .expect("tensor allocation overflow");
300        let data = alloc_f32_buffer(total);
301
302        unsafe {
303            std::ptr::copy_nonoverlapping(self.data.as_ptr(), data.as_ptr(), total);
304        }
305
306        Self {
307            data,
308            rows: self.rows,
309            cols: self.cols,
310            stride: self.stride,
311        }
312    }
313}
314
impl Drop for Tensor2D {
    fn drop(&mut self) {
        // Recompute the full allocation length (rows * stride, padding
        // included) so the freed layout matches the one used at allocation.
        let total = self
            .rows
            .checked_mul(self.stride)
            .expect("tensor allocation overflow");
        // SAFETY: `data` came from `alloc_f32_buffer(total)` and is freed once.
        unsafe {
            dealloc_f32_buffer(self.data, total);
        }
    }
}

// SAFETY: Tensor2D exclusively owns its heap buffer (no aliasing, no interior
// mutability), so moving it to, or sharing `&Tensor2D` across, threads is sound.
unsafe impl Send for Tensor2D {}
unsafe impl Sync for Tensor2D {}

/// Borrowed, read-only view into external f32 data (for weights).
#[derive(Clone, Copy)]
pub struct TensorView1D<'a> {
    data: &'a [f32],
}

impl<'a> TensorView1D<'a> {
    /// Wrap an immutable 1D slice.
    #[inline]
    pub fn new(data: &'a [f32]) -> Self {
        Self { data }
    }

    /// Number of elements in the view.
    #[inline]
    pub fn len(&self) -> usize {
        self.as_slice().len()
    }

    /// Returns `true` when the view is empty.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Raw pointer to the first element.
    #[inline]
    pub fn as_ptr(&self) -> *const f32 {
        self.as_slice().as_ptr()
    }

    /// Borrow the underlying immutable slice.
    #[inline]
    pub fn as_slice(&self) -> &[f32] {
        self.data
    }
}

impl<'a> Index<usize> for TensorView1D<'a> {
    type Output = f32;

    /// Bounds-checked element access (panics when `i` is out of range).
    #[inline]
    fn index(&self, i: usize) -> &f32 {
        &self.as_slice()[i]
    }
}

378/// View into external f32 data (for weights), row-major.
379#[derive(Clone, Copy)]
380pub struct TensorView2D<'a> {
381    data: &'a [f32],
382    rows: usize,
383    cols: usize,
384}
385
386impl<'a> TensorView2D<'a> {
387    #[inline]
388    /// Wrap row-major data with explicit `(rows, cols)` logical shape.
389    pub fn new(data: &'a [f32], rows: usize, cols: usize) -> Self {
390        debug_assert_eq!(data.len(), rows * cols);
391        Self { data, rows, cols }
392    }
393
394    #[inline]
395    /// Number of rows in the view.
396    pub fn rows(&self) -> usize {
397        self.rows
398    }
399
400    #[inline]
401    /// Number of columns in the view.
402    pub fn cols(&self) -> usize {
403        self.cols
404    }
405
406    #[inline]
407    /// Raw pointer to the first element.
408    pub fn as_ptr(&self) -> *const f32 {
409        self.data.as_ptr()
410    }
411
412    #[inline]
413    /// Borrow row `r`.
414    pub fn row(&self, r: usize) -> &[f32] {
415        debug_assert!(r < self.rows);
416        let start = r * self.cols;
417        &self.data[start..start + self.cols]
418    }
419
420    #[inline]
421    /// Pointer to the first element of row `r`.
422    pub fn row_ptr(&self, r: usize) -> *const f32 {
423        debug_assert!(r < self.rows);
424        unsafe { self.data.as_ptr().add(r * self.cols) }
425    }
426
427    /// Transpose view (returns new TensorView with swapped dims).
428    /// Note: This is a logical transpose - data is still row-major of original.
429    /// Use only for matmuls that handle transposed right operand.
430    pub fn t(&self) -> TransposedView2D<'a> {
431        TransposedView2D {
432            data: self.data,
433            rows: self.cols, // swapped
434            cols: self.rows, // swapped
435            orig_cols: self.cols,
436        }
437    }
438}
439
/// Logically transposed view (for efficient transpose-multiply kernels).
#[derive(Clone, Copy)]
pub struct TransposedView2D<'a> {
    data: &'a [f32],
    rows: usize,      // rows of the transposed shape (= original cols)
    cols: usize,      // cols of the transposed shape (= original rows)
    orig_cols: usize, // row length of the underlying row-major storage
}

impl<'a> TransposedView2D<'a> {
    /// Number of rows in the transposed shape.
    #[inline]
    pub fn rows(&self) -> usize {
        self.rows
    }

    /// Number of columns in the transposed shape.
    #[inline]
    pub fn cols(&self) -> usize {
        self.cols
    }

    /// Element at `(r, c)` of the transposed view, i.e. `(c, r)` of the
    /// original matrix (bounds-checked by the slice index).
    #[inline]
    pub fn get(&self, r: usize, c: usize) -> f32 {
        self.data[c * self.orig_cols + r]
    }

    /// Row `r` of the ORIGINAL matrix (a column of the transposed view).
    #[inline]
    pub fn orig_row(&self, r: usize) -> &[f32] {
        let begin = r * self.orig_cols;
        &self.data[begin..begin + self.orig_cols]
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn zero_len_tensor1d_uses_aligned_non_allocating_sentinel() {
        let mut empty = Tensor1D::zeros(0);
        assert!(empty.is_empty());
        assert_eq!(empty.len(), 0);
        assert!(empty.as_mut_slice().is_empty());
        assert!(empty.as_slice().is_empty());
        // The sentinel still honors the 32-byte alignment contract.
        assert_eq!((empty.as_ptr() as usize) % ALIGNMENT, 0);
        // Must be a harmless no-op, not a crash.
        empty.zero();
    }

    #[test]
    fn zero_sized_tensor2d_is_safe() {
        let mut t = Tensor2D::zeros(3, 0);
        assert_eq!((t.rows(), t.cols(), t.stride()), (3, 0, 0));
        assert_eq!((t.as_ptr() as usize) % ALIGNMENT, 0);
        for r in 0..3 {
            assert!(t.row(r).is_empty());
            assert!(t.row_mut(r).is_empty());
        }
        t.zero();
    }
}
503}