use std::alloc::{alloc_zeroed, dealloc, Layout};
use std::mem::size_of;
use std::ops::{Index, IndexMut};
use std::ptr::NonNull;

/// Byte alignment for all tensor allocations (32 bytes = one 256-bit SIMD lane).
const ALIGNMENT: usize = 32;

/// An owned, heap-allocated, 32-byte-aligned vector of `f32`.
#[repr(C)]
pub struct Tensor1D {
    data: NonNull<f32>,
    len: usize,
}

impl Tensor1D {
    /// Allocates a zero-initialized tensor of `len` elements.
    pub fn zeros(len: usize) -> Self {
        if len == 0 {
            // Zero-sized allocations are undefined behavior with the global
            // allocator, so use a dangling pointer; it is never dereferenced.
            return Self { data: NonNull::dangling(), len: 0 };
        }

        let layout =
            Layout::from_size_align(len * size_of::<f32>(), ALIGNMENT).expect("Invalid layout");

        let ptr = unsafe { alloc_zeroed(layout) as *mut f32 };
        let data = NonNull::new(ptr).expect("Allocation failed");

        Self { data, len }
    }

    pub fn from_vec(v: Vec<f32>) -> Self {
        let mut t = Self::zeros(v.len());
        t.as_mut_slice().copy_from_slice(&v);
        t
    }

    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }

    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    #[inline]
    pub fn as_ptr(&self) -> *const f32 {
        self.data.as_ptr()
    }

    #[inline]
    pub fn as_mut_ptr(&mut self) -> *mut f32 {
        self.data.as_ptr()
    }

    #[inline]
    pub fn as_slice(&self) -> &[f32] {
        unsafe { std::slice::from_raw_parts(self.data.as_ptr(), self.len) }
    }

    #[inline]
    pub fn as_mut_slice(&mut self) -> &mut [f32] {
        unsafe { std::slice::from_raw_parts_mut(self.data.as_ptr(), self.len) }
    }

    /// Sets every element to 0.0 (an all-zero bit pattern is `f32` zero).
    #[inline]
    pub fn zero(&mut self) {
        unsafe {
            std::ptr::write_bytes(self.data.as_ptr(), 0, self.len);
        }
    }

    /// Copies `other` into `self`; lengths must match.
    #[inline]
    pub fn copy_from(&mut self, other: &Tensor1D) {
        debug_assert_eq!(self.len, other.len);
        self.as_mut_slice().copy_from_slice(other.as_slice());
    }

    /// Copies `slice` into `self`; lengths must match.
    #[inline]
    pub fn copy_from_slice(&mut self, slice: &[f32]) {
        debug_assert_eq!(self.len, slice.len());
        self.as_mut_slice().copy_from_slice(slice);
    }
}

impl Clone for Tensor1D {
    fn clone(&self) -> Self {
        let mut new = Self::zeros(self.len);
        new.as_mut_slice().copy_from_slice(self.as_slice());
        new
    }
}

impl Drop for Tensor1D {
    fn drop(&mut self) {
        if self.len == 0 {
            // Nothing was allocated for an empty tensor.
            return;
        }
        let layout = Layout::from_size_align(self.len * size_of::<f32>(), ALIGNMENT)
            .expect("Invalid layout");
        unsafe {
            dealloc(self.data.as_ptr() as *mut u8, layout);
        }
    }
}

// SAFETY: Tensor1D uniquely owns its allocation, and shared references only
// permit reads, so it can be sent to and shared across threads.
unsafe impl Send for Tensor1D {}
unsafe impl Sync for Tensor1D {}

impl Index<usize> for Tensor1D {
    type Output = f32;

    #[inline]
    fn index(&self, i: usize) -> &f32 {
        debug_assert!(i < self.len);
        unsafe { &*self.data.as_ptr().add(i) }
    }
}

impl IndexMut<usize> for Tensor1D {
    #[inline]
    fn index_mut(&mut self, i: usize) -> &mut f32 {
        debug_assert!(i < self.len);
        unsafe { &mut *self.data.as_ptr().add(i) }
    }
}

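// A minimal usage sketch for Tensor1D, added as an illustrative test: it shows
// allocation, copying, and indexing. The module and test names are hypothetical,
// not part of any original test suite.
#[cfg(test)]
mod tensor1d_usage {
    use super::*;

    #[test]
    fn zeros_from_vec_and_indexing() {
        // Freshly allocated storage is zero-initialized and 32-byte aligned.
        let t = Tensor1D::zeros(5);
        assert_eq!(t.as_slice(), &[0.0; 5]);
        assert_eq!(t.as_ptr() as usize % ALIGNMENT, 0);

        // from_vec copies the data into the aligned buffer.
        let mut u = Tensor1D::from_vec(vec![1.0, 2.0, 3.0]);
        u[1] = 20.0; // IndexMut
        assert_eq!(u[1], 20.0); // Index
        assert_eq!(u.len(), 3);
    }
}
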
/// An owned, row-major `f32` matrix whose rows are padded so that every row
/// starts on a 32-byte boundary.
#[repr(C)]
pub struct Tensor2D {
    data: NonNull<f32>,
    rows: usize,
    cols: usize,
    /// Elements per row including padding (`cols` rounded up to a multiple of 8).
    stride: usize,
}

impl Tensor2D {
    /// Allocates a zero-initialized `rows x cols` matrix with padded rows.
    pub fn zeros(rows: usize, cols: usize) -> Self {
        // Round cols up to a multiple of 8 f32s (32 bytes) so each row stays aligned.
        let stride = (cols + 7) & !7;
        let total = rows * stride;

        if total == 0 {
            // Avoid the zero-sized allocation, which is undefined behavior.
            return Self { data: NonNull::dangling(), rows, cols, stride };
        }

        let layout =
            Layout::from_size_align(total * size_of::<f32>(), ALIGNMENT).expect("Invalid layout");

        let ptr = unsafe { alloc_zeroed(layout) as *mut f32 };
        let data = NonNull::new(ptr).expect("Allocation failed");

        Self {
            data,
            rows,
            cols,
            stride,
        }
    }

    /// Builds a matrix from a dense row-major `Vec` of length `rows * cols`.
    pub fn from_vec(v: Vec<f32>, rows: usize, cols: usize) -> Self {
        assert_eq!(v.len(), rows * cols);
        let mut t = Self::zeros(rows, cols);

        // The input is dense but the tensor is strided, so copy row by row.
        for r in 0..rows {
            let src_start = r * cols;
            let src_end = src_start + cols;
            t.row_mut(r).copy_from_slice(&v[src_start..src_end]);
        }
        t
    }

    #[inline]
    pub fn rows(&self) -> usize {
        self.rows
    }

    #[inline]
    pub fn cols(&self) -> usize {
        self.cols
    }

    #[inline]
    pub fn stride(&self) -> usize {
        self.stride
    }

    #[inline]
    pub fn as_ptr(&self) -> *const f32 {
        self.data.as_ptr()
    }

    #[inline]
    pub fn as_mut_ptr(&mut self) -> *mut f32 {
        self.data.as_ptr()
    }

    /// Returns row `r` as a slice of `cols` elements (padding excluded).
    #[inline]
    pub fn row(&self, r: usize) -> &[f32] {
        debug_assert!(r < self.rows);
        unsafe {
            let ptr = self.data.as_ptr().add(r * self.stride);
            std::slice::from_raw_parts(ptr, self.cols)
        }
    }

    #[inline]
    pub fn row_mut(&mut self, r: usize) -> &mut [f32] {
        debug_assert!(r < self.rows);
        unsafe {
            let ptr = self.data.as_ptr().add(r * self.stride);
            std::slice::from_raw_parts_mut(ptr, self.cols)
        }
    }

    #[inline]
    pub fn row_ptr(&self, r: usize) -> *const f32 {
        debug_assert!(r < self.rows);
        unsafe { self.data.as_ptr().add(r * self.stride) }
    }

    #[inline]
    pub fn row_ptr_mut(&mut self, r: usize) -> *mut f32 {
        debug_assert!(r < self.rows);
        unsafe { self.data.as_ptr().add(r * self.stride) }
    }

    /// Sets every element (including row padding) to 0.0.
    pub fn zero(&mut self) {
        let total = self.rows * self.stride;
        unsafe {
            std::ptr::write_bytes(self.data.as_ptr(), 0, total);
        }
    }
}

impl Clone for Tensor2D {
    fn clone(&self) -> Self {
        // Reuse zeros() so the empty case and the layout logic live in one
        // place, then copy the whole buffer, padding included.
        let new = Self::zeros(self.rows, self.cols);
        let total = self.rows * self.stride;
        unsafe {
            std::ptr::copy_nonoverlapping(self.data.as_ptr(), new.data.as_ptr(), total);
        }
        new
    }
}

impl Drop for Tensor2D {
    fn drop(&mut self) {
        let total = self.rows * self.stride;
        if total == 0 {
            // Nothing was allocated for an empty matrix.
            return;
        }
        let layout =
            Layout::from_size_align(total * size_of::<f32>(), ALIGNMENT).expect("Invalid layout");
        unsafe {
            dealloc(self.data.as_ptr() as *mut u8, layout);
        }
    }
}

// SAFETY: Tensor2D uniquely owns its allocation; see Tensor1D above.
unsafe impl Send for Tensor2D {}
unsafe impl Sync for Tensor2D {}

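// A minimal usage sketch for Tensor2D, added as an illustrative test: it shows
// the row padding, where 3 columns round up to a stride of 8. The module and
// test names are hypothetical.
#[cfg(test)]
mod tensor2d_usage {
    use super::*;

    #[test]
    fn strided_rows() {
        let v: Vec<f32> = (0..6).map(|x| x as f32).collect();
        let t = Tensor2D::from_vec(v, 2, 3);

        assert_eq!(t.stride(), 8); // 3 cols padded to the next multiple of 8
        assert_eq!(t.row(0), &[0.0, 1.0, 2.0]);
        assert_eq!(t.row(1), &[3.0, 4.0, 5.0]);

        // Each row starts on a 32-byte boundary.
        assert_eq!(t.row_ptr(1) as usize % ALIGNMENT, 0);
    }
}
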
/// A borrowed, read-only 1-D view over a contiguous `f32` slice.
#[derive(Clone, Copy)]
pub struct TensorView1D<'a> {
    data: &'a [f32],
}

impl<'a> TensorView1D<'a> {
    #[inline]
    pub fn new(data: &'a [f32]) -> Self {
        Self { data }
    }

    #[inline]
    pub fn len(&self) -> usize {
        self.data.len()
    }

    #[inline]
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }

    #[inline]
    pub fn as_ptr(&self) -> *const f32 {
        self.data.as_ptr()
    }

    #[inline]
    pub fn as_slice(&self) -> &[f32] {
        self.data
    }
}

impl<'a> Index<usize> for TensorView1D<'a> {
    type Output = f32;

    #[inline]
    fn index(&self, i: usize) -> &f32 {
        &self.data[i]
    }
}

/// A borrowed, read-only 2-D view over a dense (unpadded) row-major slice.
#[derive(Clone, Copy)]
pub struct TensorView2D<'a> {
    data: &'a [f32],
    rows: usize,
    cols: usize,
}

impl<'a> TensorView2D<'a> {
    #[inline]
    pub fn new(data: &'a [f32], rows: usize, cols: usize) -> Self {
        debug_assert_eq!(data.len(), rows * cols);
        Self { data, rows, cols }
    }

    #[inline]
    pub fn rows(&self) -> usize {
        self.rows
    }

    #[inline]
    pub fn cols(&self) -> usize {
        self.cols
    }

    #[inline]
    pub fn as_ptr(&self) -> *const f32 {
        self.data.as_ptr()
    }

    #[inline]
    pub fn row(&self, r: usize) -> &[f32] {
        debug_assert!(r < self.rows);
        let start = r * self.cols;
        &self.data[start..start + self.cols]
    }

    #[inline]
    pub fn row_ptr(&self, r: usize) -> *const f32 {
        debug_assert!(r < self.rows);
        unsafe { self.data.as_ptr().add(r * self.cols) }
    }

    /// Returns a transposed view without copying any data: the dimensions are
    /// swapped and reads are remapped through the original column count.
    pub fn t(&self) -> TransposedView2D<'a> {
        TransposedView2D {
            data: self.data,
            rows: self.cols,
            cols: self.rows,
            orig_cols: self.cols,
        }
    }
}

/// A lazily transposed view produced by [`TensorView2D::t`]; indexing is
/// remapped instead of moving data.
#[derive(Clone, Copy)]
pub struct TransposedView2D<'a> {
    data: &'a [f32],
    rows: usize,
    cols: usize,
    orig_cols: usize,
}

impl<'a> TransposedView2D<'a> {
    #[inline]
    pub fn rows(&self) -> usize {
        self.rows
    }

    #[inline]
    pub fn cols(&self) -> usize {
        self.cols
    }

    /// Element `(r, c)` of the transpose, i.e. `(c, r)` of the original.
    #[inline]
    pub fn get(&self, r: usize, c: usize) -> f32 {
        debug_assert!(r < self.rows && c < self.cols);
        self.data[c * self.orig_cols + r]
    }

    /// Row `r` of the *original* (untransposed) matrix, for code that prefers
    /// walking contiguous memory.
    #[inline]
    pub fn orig_row(&self, r: usize) -> &[f32] {
        debug_assert!(r < self.cols);
        let start = r * self.orig_cols;
        &self.data[start..start + self.orig_cols]
    }
}
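
// A minimal sketch of the view types, added as an illustrative test: wrapping
// a dense slice and reading it through the zero-copy transpose. The module and
// test names are hypothetical.
#[cfg(test)]
mod view_usage {
    use super::*;

    #[test]
    fn transpose_remaps_indices() {
        // 2x3 matrix: [[1, 2, 3],
        //              [4, 5, 6]]
        let data = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let view = TensorView2D::new(&data, 2, 3);
        assert_eq!(view.row(1), &[4.0, 5.0, 6.0]);

        // The transpose is 3x2; get(r, c) reads element (c, r) of the original.
        let t = view.t();
        assert_eq!((t.rows(), t.cols()), (3, 2));
        assert_eq!(t.get(2, 1), 6.0); // original (1, 2)

        // orig_row still walks the original contiguous rows.
        assert_eq!(t.orig_row(0), &[1.0, 2.0, 3.0]);
    }
}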