// infotheory/backends/rwkvzip/rwkv7/tensor.rs

use std::alloc::{Layout, alloc_zeroed, dealloc};
use std::mem::size_of;
use std::ops::{Index, IndexMut};
use std::ptr::NonNull;
/// Byte alignment of every tensor buffer: 32 bytes = 8 f32 elements,
/// matching the 8-element row rounding done by `padded_stride`.
const ALIGNMENT: usize = 32;
15
16#[inline]
17fn dangling_aligned_f32() -> NonNull<f32> {
18 debug_assert_eq!(ALIGNMENT % std::mem::align_of::<f32>(), 0);
19 NonNull::new(ALIGNMENT as *mut u8)
20 .expect("aligned dangling pointer must be non-null")
21 .cast()
22}
23
24#[inline]
25fn layout_for_f32_elems(len: usize) -> Layout {
26 let bytes = len
27 .checked_mul(size_of::<f32>())
28 .expect("tensor allocation overflow");
29 Layout::from_size_align(bytes, ALIGNMENT).expect("Invalid layout")
30}
31
32#[inline]
33fn alloc_f32_buffer(len: usize) -> NonNull<f32> {
34 if len == 0 {
35 return dangling_aligned_f32();
36 }
37 let layout = layout_for_f32_elems(len);
38 let ptr = unsafe { alloc_zeroed(layout) };
39 NonNull::new(ptr).expect("Allocation failed").cast()
40}
41
42#[inline]
43unsafe fn dealloc_f32_buffer(ptr: NonNull<f32>, len: usize) {
44 if len == 0 {
45 return;
46 }
47 let layout = layout_for_f32_elems(len);
48 unsafe {
49 dealloc(ptr.as_ptr() as *mut u8, layout);
50 }
51}
52
/// Rounds `cols` up to the next multiple of 8 so every row occupies a whole
/// number of 32-byte (8 x f32) chunks; panics on `usize` overflow.
#[inline]
fn padded_stride(cols: usize) -> usize {
    let bumped = cols.checked_add(7).expect("tensor stride overflow");
    bumped & !7
}
57
/// Owned, heap-allocated 1-D f32 buffer, 32-byte aligned.
#[repr(C)]
pub struct Tensor1D {
    // Non-null, aligned pointer to the buffer; the aligned dangling sentinel
    // when `len == 0` (no allocation backs it in that case).
    data: NonNull<f32>,
    // Number of f32 elements the buffer holds.
    len: usize,
}
64
65impl Tensor1D {
66 pub fn zeros(len: usize) -> Self {
68 Self {
69 data: alloc_f32_buffer(len),
70 len,
71 }
72 }
73
74 pub fn from_vec(v: Vec<f32>) -> Self {
76 let mut t = Self::zeros(v.len());
77 t.as_mut_slice().copy_from_slice(&v);
78 t
79 }
80
81 #[inline]
82 pub fn len(&self) -> usize {
84 self.len
85 }
86
87 #[inline]
88 pub fn is_empty(&self) -> bool {
90 self.len == 0
91 }
92
93 #[inline]
94 pub fn as_ptr(&self) -> *const f32 {
96 self.data.as_ptr()
97 }
98
99 #[inline]
100 pub fn as_mut_ptr(&mut self) -> *mut f32 {
102 self.data.as_ptr()
103 }
104
105 #[inline]
106 pub fn as_slice(&self) -> &[f32] {
108 unsafe { std::slice::from_raw_parts(self.data.as_ptr(), self.len) }
109 }
110
111 #[inline]
112 pub fn as_mut_slice(&mut self) -> &mut [f32] {
114 unsafe { std::slice::from_raw_parts_mut(self.data.as_ptr(), self.len) }
115 }
116
117 #[inline]
119 pub fn zero(&mut self) {
120 unsafe {
121 std::ptr::write_bytes(self.data.as_ptr(), 0, self.len);
122 }
123 }
124
125 #[inline]
127 pub fn copy_from(&mut self, other: &Tensor1D) {
128 debug_assert_eq!(self.len, other.len);
129 self.as_mut_slice().copy_from_slice(other.as_slice());
130 }
131
132 #[inline]
134 pub fn copy_from_slice(&mut self, slice: &[f32]) {
135 debug_assert_eq!(self.len, slice.len());
136 self.as_mut_slice().copy_from_slice(slice);
137 }
138}
139
140impl Clone for Tensor1D {
141 fn clone(&self) -> Self {
142 let mut new = Self::zeros(self.len);
143 new.as_mut_slice().copy_from_slice(self.as_slice());
144 new
145 }
146}
147
148impl Drop for Tensor1D {
149 fn drop(&mut self) {
150 unsafe {
151 dealloc_f32_buffer(self.data, self.len);
152 }
153 }
154}
155
// SAFETY(review): Tensor1D uniquely owns its buffer (allocated in `zeros`,
// freed once in `drop`) and exposes it only through `&self`/`&mut self`
// methods, so cross-thread moves and shared reads look sound — assumes
// callers do not retain raw pointers from `as_ptr`/`as_mut_ptr` across
// threads; confirm at call sites.
unsafe impl Send for Tensor1D {}
unsafe impl Sync for Tensor1D {}
159
160impl Index<usize> for Tensor1D {
161 type Output = f32;
162
163 #[inline]
164 fn index(&self, i: usize) -> &f32 {
165 debug_assert!(i < self.len);
166 unsafe { &*self.data.as_ptr().add(i) }
167 }
168}
169
170impl IndexMut<usize> for Tensor1D {
171 #[inline]
172 fn index_mut(&mut self, i: usize) -> &mut f32 {
173 debug_assert!(i < self.len);
174 unsafe { &mut *self.data.as_ptr().add(i) }
175 }
176}
177
/// Owned, row-major 2-D f32 buffer whose rows are padded to `stride`
/// elements so each row start stays 32-byte aligned.
#[repr(C)]
pub struct Tensor2D {
    // Buffer of `rows * stride` f32s; non-null, 32-byte aligned (sentinel
    // when the buffer is zero-sized).
    data: NonNull<f32>,
    // Logical row count.
    rows: usize,
    // Logical (unpadded) column count.
    cols: usize,
    // Elements per row in memory: `cols` rounded up to a multiple of 8.
    stride: usize,
}
186
187impl Tensor2D {
188 pub fn zeros(rows: usize, cols: usize) -> Self {
190 let stride = padded_stride(cols);
192 let total = rows
193 .checked_mul(stride)
194 .expect("tensor allocation overflow");
195
196 Self {
197 data: alloc_f32_buffer(total),
198 rows,
199 cols,
200 stride,
201 }
202 }
203
204 pub fn from_vec(v: Vec<f32>, rows: usize, cols: usize) -> Self {
206 assert_eq!(v.len(), rows * cols);
207 let mut t = Self::zeros(rows, cols);
208
209 for r in 0..rows {
211 let src_start = r * cols;
212 let src_end = src_start + cols;
213 t.row_mut(r).copy_from_slice(&v[src_start..src_end]);
214 }
215 t
216 }
217
218 #[inline]
219 pub fn rows(&self) -> usize {
221 self.rows
222 }
223
224 #[inline]
225 pub fn cols(&self) -> usize {
227 self.cols
228 }
229
230 #[inline]
231 pub fn stride(&self) -> usize {
233 self.stride
234 }
235
236 #[inline]
237 pub fn as_ptr(&self) -> *const f32 {
239 self.data.as_ptr()
240 }
241
242 #[inline]
243 pub fn as_mut_ptr(&mut self) -> *mut f32 {
245 self.data.as_ptr()
246 }
247
248 #[inline]
250 pub fn row(&self, r: usize) -> &[f32] {
251 debug_assert!(r < self.rows);
252 unsafe {
253 let ptr = self.data.as_ptr().add(r * self.stride);
254 std::slice::from_raw_parts(ptr, self.cols)
255 }
256 }
257
258 #[inline]
260 pub fn row_mut(&mut self, r: usize) -> &mut [f32] {
261 debug_assert!(r < self.rows);
262 unsafe {
263 let ptr = self.data.as_ptr().add(r * self.stride);
264 std::slice::from_raw_parts_mut(ptr, self.cols)
265 }
266 }
267
268 #[inline]
270 pub fn row_ptr(&self, r: usize) -> *const f32 {
271 debug_assert!(r < self.rows);
272 unsafe { self.data.as_ptr().add(r * self.stride) }
273 }
274
275 #[inline]
277 pub fn row_ptr_mut(&mut self, r: usize) -> *mut f32 {
278 debug_assert!(r < self.rows);
279 unsafe { self.data.as_ptr().add(r * self.stride) }
280 }
281
282 pub fn zero(&mut self) {
284 let total = self
285 .rows
286 .checked_mul(self.stride)
287 .expect("tensor allocation overflow");
288 unsafe {
289 std::ptr::write_bytes(self.data.as_ptr(), 0, total);
290 }
291 }
292}
293
294impl Clone for Tensor2D {
295 fn clone(&self) -> Self {
296 let total = self
297 .rows
298 .checked_mul(self.stride)
299 .expect("tensor allocation overflow");
300 let data = alloc_f32_buffer(total);
301
302 unsafe {
303 std::ptr::copy_nonoverlapping(self.data.as_ptr(), data.as_ptr(), total);
304 }
305
306 Self {
307 data,
308 rows: self.rows,
309 cols: self.cols,
310 stride: self.stride,
311 }
312 }
313}
314
315impl Drop for Tensor2D {
316 fn drop(&mut self) {
317 let total = self
318 .rows
319 .checked_mul(self.stride)
320 .expect("tensor allocation overflow");
321 unsafe {
322 dealloc_f32_buffer(self.data, total);
323 }
324 }
325}
326
// SAFETY(review): Tensor2D uniquely owns its buffer (allocated in `zeros`,
// freed once in `drop`) and mediates all access through `&self`/`&mut self`
// methods, so Send/Sync look sound — assumes callers do not keep raw
// pointers from `as_ptr`/`row_ptr*` alive across threads; confirm at call
// sites.
unsafe impl Send for Tensor2D {}
unsafe impl Sync for Tensor2D {}
330
/// Borrowed, read-only view over a contiguous run of f32s.
#[derive(Clone, Copy)]
pub struct TensorView1D<'a> {
    data: &'a [f32],
}

impl<'a> TensorView1D<'a> {
    /// Wraps an existing slice without copying.
    #[inline]
    pub fn new(data: &'a [f32]) -> Self {
        Self { data }
    }

    /// Number of elements viewed.
    #[inline]
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// True when the view covers no elements.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }

    /// Pointer to the first viewed element.
    #[inline]
    pub fn as_ptr(&self) -> *const f32 {
        self.data.as_ptr()
    }

    /// The underlying slice.
    #[inline]
    pub fn as_slice(&self) -> &[f32] {
        self.data
    }
}

impl<'a> Index<usize> for TensorView1D<'a> {
    type Output = f32;

    /// Bounds-checked element access, delegating to slice indexing.
    #[inline]
    fn index(&self, i: usize) -> &f32 {
        &self.data[i]
    }
}
377
/// Borrowed, read-only, row-major 2-D view over a dense (unpadded) slice.
#[derive(Clone, Copy)]
pub struct TensorView2D<'a> {
    data: &'a [f32],
    rows: usize,
    cols: usize,
}

impl<'a> TensorView2D<'a> {
    /// Wraps a row-major slice of exactly `rows * cols` elements.
    ///
    /// # Panics
    /// Panics if `rows * cols` overflows or differs from `data.len()`.
    /// Previously a `debug_assert!`; the check must also hold in release
    /// builds because `row_ptr` derives raw-pointer offsets from these
    /// dimensions.
    #[inline]
    pub fn new(data: &'a [f32], rows: usize, cols: usize) -> Self {
        let expected = rows.checked_mul(cols).expect("view dimensions overflow");
        assert_eq!(data.len(), expected);
        Self { data, rows, cols }
    }

    /// Logical row count.
    #[inline]
    pub fn rows(&self) -> usize {
        self.rows
    }

    /// Logical column count.
    #[inline]
    pub fn cols(&self) -> usize {
        self.cols
    }

    /// Pointer to the first element of the underlying slice.
    #[inline]
    pub fn as_ptr(&self) -> *const f32 {
        self.data.as_ptr()
    }

    /// Borrows row `r`; out-of-range `r` panics via the slice bounds check.
    #[inline]
    pub fn row(&self, r: usize) -> &[f32] {
        debug_assert!(r < self.rows);
        let start = r * self.cols;
        &self.data[start..start + self.cols]
    }

    /// Pointer to the start of row `r`.
    ///
    /// # Panics
    /// Panics if `r >= rows`. This was a `debug_assert!`, which made an
    /// out-of-range `r` undefined behavior in release builds (pointer
    /// arithmetic past the slice); a safe method must keep the check.
    #[inline]
    pub fn row_ptr(&self, r: usize) -> *const f32 {
        assert!(r < self.rows);
        // SAFETY: `new` guarantees data.len() == rows * cols and r < rows,
        // so the offset stays within the slice.
        unsafe { self.data.as_ptr().add(r * self.cols) }
    }

    /// Returns a transposed (cols x rows) logical view sharing the same
    /// storage; no data is moved, indices are swapped at access time.
    pub fn t(&self) -> TransposedView2D<'a> {
        TransposedView2D {
            data: self.data,
            rows: self.cols,
            cols: self.rows,
            orig_cols: self.cols,
        }
    }
}

/// Lazily transposed counterpart of `TensorView2D`: element (r, c) here is
/// element (c, r) of the original view.
#[derive(Clone, Copy)]
pub struct TransposedView2D<'a> {
    data: &'a [f32],
    rows: usize,
    cols: usize,
    // Column count of the ORIGINAL (untransposed) view, used to map the
    // transposed (r, c) back onto row-major storage.
    orig_cols: usize,
}

impl<'a> TransposedView2D<'a> {
    /// Row count of the transposed view (== original cols).
    #[inline]
    pub fn rows(&self) -> usize {
        self.rows
    }

    /// Column count of the transposed view (== original rows).
    #[inline]
    pub fn cols(&self) -> usize {
        self.cols
    }

    /// Element at transposed position (r, c), i.e. original (c, r);
    /// bounds-checked by slice indexing.
    #[inline]
    pub fn get(&self, r: usize, c: usize) -> f32 {
        self.data[c * self.orig_cols + r]
    }

    /// Row `r` of the ORIGINAL (untransposed) matrix.
    #[inline]
    pub fn orig_row(&self, r: usize) -> &[f32] {
        let start = r * self.orig_cols;
        &self.data[start..start + self.orig_cols]
    }
}
474
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn zero_len_tensor1d_uses_aligned_non_allocating_sentinel() {
        let mut tensor = Tensor1D::zeros(0);
        assert_eq!(tensor.len(), 0);
        assert!(tensor.is_empty());
        assert!(tensor.as_slice().is_empty());
        assert!(tensor.as_mut_slice().is_empty());
        // Even the sentinel pointer must honor the alignment contract.
        assert_eq!((tensor.as_ptr() as usize) % ALIGNMENT, 0);
        // Zeroing an empty tensor must be a harmless no-op.
        tensor.zero();
    }

    #[test]
    fn zero_sized_tensor2d_is_safe() {
        let mut matrix = Tensor2D::zeros(3, 0);
        assert_eq!(matrix.rows(), 3);
        assert_eq!(matrix.cols(), 0);
        assert_eq!(matrix.stride(), 0);
        assert_eq!((matrix.as_ptr() as usize) % ALIGNMENT, 0);
        // Every row of a 3 x 0 matrix is a valid empty slice.
        for r in 0..matrix.rows() {
            assert!(matrix.row(r).is_empty());
            assert!(matrix.row_mut(r).is_empty());
        }
        matrix.zero();
    }
}