rwkvzip/rwkv7/
weights.rs

1//! Native safetensors weight loading for RWKV7.
2//!
//! This module provides a self-contained safetensors parser optimized for loading
//! RWKV7 model weights; it needs no JSON or safetensors crate. The implementation
//! parses the safetensors JSON header by hand and loads FP32 tensor data only.
6//!
7//! # File Format
8//!
9//! Safetensors files have a simple structure:
10//! - 8 bytes: header length (little-endian u64)
11//! - N bytes: JSON header containing tensor metadata
12//! - Remaining bytes: contiguous tensor data
13//!
14//! The JSON header maps tensor names to their dtype, shape, and data_offsets.
15
16use ahash::{HashMap, HashMapExt};
17use std::fs::File;
18use std::io::{Read, Seek, SeekFrom};
19use std::path::Path;
20
21use anyhow::{bail, Context, Result};
22
23use super::tensor::{Tensor1D, Tensor2D};
24
25// =============================================================================
26// Tensor Metadata
27// =============================================================================
28
/// Description of one tensor as recorded in the safetensors header.
#[derive(Clone, Debug)]
struct TensorMeta {
    /// Size of each dimension (a 256x256 matrix is stored as `[256, 256]`).
    shape: Vec<usize>,
    /// First byte of the tensor payload, measured from the start of the data section.
    offset_start: usize,
    /// One past the last byte of the tensor payload, measured from the same origin.
    offset_end: usize,
}
39
40// =============================================================================
41// Weight Storage
42// =============================================================================
43
/// A tensor held in memory: its FP32 values together with its recorded shape.
pub struct WeightTensor {
    /// FP32 values stored flat, in row-major order.
    data: Vec<f32>,
    /// Size of each dimension.
    shape: Vec<usize>,
}

impl WeightTensor {
    /// Borrow the flat value buffer.
    #[inline]
    pub fn data(&self) -> &[f32] {
        self.data.as_slice()
    }

    /// Borrow the dimension sizes.
    #[inline]
    pub fn shape(&self) -> &[usize] {
        self.shape.as_slice()
    }

    /// Number of stored values (the length of the flat buffer).
    #[inline]
    pub fn numel(&self) -> usize {
        self.data().len()
    }

    /// Borrow the values as a single flat slice, whatever the shape is.
    #[inline]
    pub fn as_1d(&self) -> &[f32] {
        self.data()
    }

    /// Yield `rows` consecutive slices of `cols` values each.
    ///
    /// Debug builds assert that `rows * cols` equals the number of stored values.
    /// Slicing happens lazily, so an out-of-range row panics only when the
    /// iterator reaches it.
    pub fn as_2d(&self, rows: usize, cols: usize) -> impl Iterator<Item = &[f32]> {
        debug_assert_eq!(rows * cols, self.data.len());
        let values = self.data.as_slice();
        (0..rows).map(move |row| {
            let begin = row * cols;
            &values[begin..begin + cols]
        })
    }
}
83
84// =============================================================================
85// Weights Container
86// =============================================================================
87
88/// Container for all loaded RWKV7 model weights.
89pub struct Weights {
90    pub(crate) tensors: HashMap<String, WeightTensor>,
91}
92
93impl Weights {
94    /// Load weights from a safetensors file.
95    ///
96    /// This function:
97    /// 1. Reads and parses the JSON header to extract tensor metadata
98    /// 2. Loads each tensor's FP32 data into memory
99    /// 3. Returns a Weights container for efficient tensor lookup
100    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
101        let path = path.as_ref();
102        let mut file = File::open(path)
103            .with_context(|| format!("Failed to open weights file: {}", path.display()))?;
104
105        // Read header length (8-byte little-endian u64)
106        let mut header_len_bytes = [0u8; 8];
107        file.read_exact(&mut header_len_bytes)?;
108        let header_len = u64::from_le_bytes(header_len_bytes) as usize;
109
110        // Read JSON header
111        let mut header_bytes = vec![0u8; header_len];
112        file.read_exact(&mut header_bytes)?;
113        let header_str =
114            std::str::from_utf8(&header_bytes).context("Invalid UTF-8 in safetensors header")?;
115
116        // Parse tensor metadata from JSON header
117        let metas = parse_safetensors_header(header_str)?;
118
119        // Data section starts after the 8-byte length + header
120        let data_offset = 8 + header_len;
121
122        // Load all tensors
123        let mut tensors = HashMap::with_capacity(metas.len());
124        for (name, meta) in metas {
125            // Seek to tensor data
126            file.seek(SeekFrom::Start((data_offset + meta.offset_start) as u64))?;
127
128            let byte_len = meta.offset_end - meta.offset_start;
129            let mut raw_bytes = vec![0u8; byte_len];
130            file.read_exact(&mut raw_bytes)?;
131
132            // Convert bytes to f32 (little-endian)
133            let data = bytes_to_f32(&raw_bytes);
134
135            tensors.insert(
136                name,
137                WeightTensor {
138                    data,
139                    shape: meta.shape,
140                },
141            );
142        }
143
144        Ok(Self { tensors })
145    }
146
147    /// Get a tensor by name, returning None if not found.
148    #[inline]
149    pub fn get(&self, name: &str) -> Option<&WeightTensor> {
150        self.tensors.get(name)
151    }
152
153    /// Get a tensor by name, or return an error if not found.
154    pub fn require(&self, name: &str) -> Result<&WeightTensor> {
155        self.tensors
156            .get(name)
157            .with_context(|| format!("Missing required tensor: {}", name))
158    }
159
160    /// Get a tensor as a 1D aligned tensor.
161    pub fn get_1d(&self, name: &str) -> Result<Tensor1D> {
162        let t = self.require(name)?;
163        Ok(Tensor1D::from_vec(t.data.clone()))
164    }
165
166    /// Get a tensor as a 2D aligned tensor.
167    pub fn get_2d(&self, name: &str) -> Result<Tensor2D> {
168        let t = self.require(name)?;
169        match t.shape.len() {
170            1 => Ok(Tensor2D::from_vec(t.data.clone(), 1, t.shape[0])),
171            2 => Ok(Tensor2D::from_vec(t.data.clone(), t.shape[0], t.shape[1])),
172            _ => bail!(
173                "Expected 1D or 2D tensor for '{}', got shape {:?}",
174                name,
175                t.shape
176            ),
177        }
178    }
179
180    /// Iterate over all tensor names.
181    pub fn tensor_names(&self) -> impl Iterator<Item = &str> {
182        self.tensors.keys().map(|s| s.as_str())
183    }
184
185    /// Print summary of all loaded tensors (for debugging).
186    pub fn print_summary(&self) {
187        let mut names: Vec<_> = self.tensors.keys().collect();
188        names.sort();
189        for name in names {
190            let t = &self.tensors[name];
191            println!("  {} {:?} ({} params)", name, t.shape, t.numel());
192        }
193    }
194}
195
196// =============================================================================
197// JSON Header Parser
198// =============================================================================
199
200/// Parse the safetensors JSON header to extract tensor metadata.
201///
202/// This is a minimal, hand-written parser that only extracts the fields we need:
203/// - shape: array of dimension sizes
204/// - data_offsets: [start, end] byte offsets
205///
206/// We skip dtype since we assume all tensors are FP32.
207fn parse_safetensors_header(json: &str) -> Result<HashMap<String, TensorMeta>> {
208    let bytes = json.as_bytes();
209    let mut pos = 0;
210    let mut metas = HashMap::new();
211
212    // Skip whitespace and opening brace
213    skip_whitespace(bytes, &mut pos);
214    expect_char(bytes, &mut pos, b'{')?;
215
216    loop {
217        skip_whitespace(bytes, &mut pos);
218
219        // Check for end of object
220        if pos < bytes.len() && bytes[pos] == b'}' {
221            break;
222        }
223
224        // Skip comma between entries
225        if pos < bytes.len() && bytes[pos] == b',' {
226            pos += 1;
227            skip_whitespace(bytes, &mut pos);
228        }
229
230        // Parse tensor name
231        let name = parse_string(bytes, &mut pos)?;
232
233        // Skip __metadata__ entries
234        if name == "__metadata__" {
235            skip_whitespace(bytes, &mut pos);
236            expect_char(bytes, &mut pos, b':')?;
237            skip_json_value(bytes, &mut pos)?;
238            continue;
239        }
240
241        skip_whitespace(bytes, &mut pos);
242        expect_char(bytes, &mut pos, b':')?;
243        skip_whitespace(bytes, &mut pos);
244
245        // Parse tensor info object
246        let meta = parse_tensor_info(bytes, &mut pos)?;
247        metas.insert(name, meta);
248    }
249
250    Ok(metas)
251}
252
253/// Parse a tensor info object: { "dtype": "...", "shape": [...], "data_offsets": [...] }
254fn parse_tensor_info(bytes: &[u8], pos: &mut usize) -> Result<TensorMeta> {
255    expect_char(bytes, pos, b'{')?;
256
257    let mut shape: Option<Vec<usize>> = None;
258    let mut offset_start: Option<usize> = None;
259    let mut offset_end: Option<usize> = None;
260
261    loop {
262        skip_whitespace(bytes, pos);
263
264        if *pos < bytes.len() && bytes[*pos] == b'}' {
265            *pos += 1;
266            break;
267        }
268
269        if *pos < bytes.len() && bytes[*pos] == b',' {
270            *pos += 1;
271            skip_whitespace(bytes, pos);
272        }
273
274        let key = parse_string(bytes, pos)?;
275        skip_whitespace(bytes, pos);
276        expect_char(bytes, pos, b':')?;
277        skip_whitespace(bytes, pos);
278
279        match key.as_str() {
280            "shape" => {
281                shape = Some(parse_int_array(bytes, pos)?);
282            }
283            "data_offsets" => {
284                let offsets = parse_int_array(bytes, pos)?;
285                if offsets.len() >= 2 {
286                    offset_start = Some(offsets[0]);
287                    offset_end = Some(offsets[1]);
288                }
289            }
290            _ => {
291                // Skip dtype and any other fields
292                skip_json_value(bytes, pos)?;
293            }
294        }
295    }
296
297    Ok(TensorMeta {
298        shape: shape.unwrap_or_default(),
299        offset_start: offset_start.unwrap_or(0),
300        offset_end: offset_end.unwrap_or(0),
301    })
302}
303
304/// Parse a JSON string (expects opening quote at current position).
305fn parse_string(bytes: &[u8], pos: &mut usize) -> Result<String> {
306    expect_char(bytes, pos, b'"')?;
307
308    let start = *pos;
309    while *pos < bytes.len() && bytes[*pos] != b'"' {
310        if bytes[*pos] == b'\\' {
311            *pos += 1; // Skip escape character
312        }
313        *pos += 1;
314    }
315    let end = *pos;
316
317    expect_char(bytes, pos, b'"')?;
318
319    String::from_utf8(bytes[start..end].to_vec()).context("Invalid UTF-8 in JSON string")
320}
321
322/// Parse a JSON array of integers.
323fn parse_int_array(bytes: &[u8], pos: &mut usize) -> Result<Vec<usize>> {
324    expect_char(bytes, pos, b'[')?;
325
326    let mut result = Vec::new();
327
328    loop {
329        skip_whitespace(bytes, pos);
330
331        if *pos < bytes.len() && bytes[*pos] == b']' {
332            *pos += 1;
333            break;
334        }
335
336        if *pos < bytes.len() && bytes[*pos] == b',' {
337            *pos += 1;
338            skip_whitespace(bytes, pos);
339        }
340
341        result.push(parse_int(bytes, pos)?);
342    }
343
344    Ok(result)
345}
346
347/// Parse a single integer.
348fn parse_int(bytes: &[u8], pos: &mut usize) -> Result<usize> {
349    let start = *pos;
350    while *pos < bytes.len() && bytes[*pos].is_ascii_digit() {
351        *pos += 1;
352    }
353
354    if start == *pos {
355        bail!("Expected integer at position {}", *pos);
356    }
357
358    let s = std::str::from_utf8(&bytes[start..*pos])?;
359    s.parse().context("Failed to parse integer")
360}
361
362/// Skip a JSON value (string, number, object, array, boolean, null).
363fn skip_json_value(bytes: &[u8], pos: &mut usize) -> Result<()> {
364    skip_whitespace(bytes, pos);
365
366    if *pos >= bytes.len() {
367        return Ok(());
368    }
369
370    match bytes[*pos] {
371        b'"' => {
372            // String
373            *pos += 1;
374            while *pos < bytes.len() && bytes[*pos] != b'"' {
375                if bytes[*pos] == b'\\' {
376                    *pos += 1;
377                }
378                *pos += 1;
379            }
380            *pos += 1; // Skip closing quote
381        }
382        b'{' => {
383            // Object
384            let mut depth = 1;
385            *pos += 1;
386            while *pos < bytes.len() && depth > 0 {
387                match bytes[*pos] {
388                    b'{' => depth += 1,
389                    b'}' => depth -= 1,
390                    b'"' => {
391                        *pos += 1;
392                        while *pos < bytes.len() && bytes[*pos] != b'"' {
393                            if bytes[*pos] == b'\\' {
394                                *pos += 1;
395                            }
396                            *pos += 1;
397                        }
398                    }
399                    _ => {}
400                }
401                *pos += 1;
402            }
403        }
404        b'[' => {
405            // Array
406            let mut depth = 1;
407            *pos += 1;
408            while *pos < bytes.len() && depth > 0 {
409                match bytes[*pos] {
410                    b'[' => depth += 1,
411                    b']' => depth -= 1,
412                    b'"' => {
413                        *pos += 1;
414                        while *pos < bytes.len() && bytes[*pos] != b'"' {
415                            if bytes[*pos] == b'\\' {
416                                *pos += 1;
417                            }
418                            *pos += 1;
419                        }
420                    }
421                    _ => {}
422                }
423                *pos += 1;
424            }
425        }
426        _ => {
427            // Number, boolean, or null
428            while *pos < bytes.len() && !matches!(bytes[*pos], b',' | b'}' | b']') {
429                *pos += 1;
430            }
431        }
432    }
433
434    Ok(())
435}
436
/// Advance `pos` past any run of ASCII whitespace.
///
/// Does nothing when `pos` is already at or beyond the end of `bytes`.
#[inline]
fn skip_whitespace(bytes: &[u8], pos: &mut usize) {
    if let Some(rest) = bytes.get(*pos..) {
        *pos += rest
            .iter()
            .take_while(|byte| byte.is_ascii_whitespace())
            .count();
    }
}
444
445/// Expect a specific character at the current position.
446#[inline]
447fn expect_char(bytes: &[u8], pos: &mut usize, expected: u8) -> Result<()> {
448    if *pos >= bytes.len() || bytes[*pos] != expected {
449        bail!(
450            "Expected '{}' at position {}, found '{}'",
451            expected as char,
452            *pos,
453            bytes.get(*pos).map(|&b| b as char).unwrap_or('\0')
454        );
455    }
456    *pos += 1;
457    Ok(())
458}
459
460// =============================================================================
461// Byte Conversion
462// =============================================================================
463
/// Convert a byte slice to FP32 values (little-endian).
///
/// Every complete group of 4 bytes becomes one `f32`; up to 3 trailing bytes
/// that do not form a full value are ignored.
///
/// This runs once per tensor during model loading. Iterating with
/// `chunks_exact` gives the compiler fixed-size chunks to work with, so no
/// per-index bounds checks or zero-initialised output buffer are needed.
#[inline]
fn bytes_to_f32(bytes: &[u8]) -> Vec<f32> {
    bytes
        .chunks_exact(std::mem::size_of::<f32>())
        .map(|word| f32::from_le_bytes([word[0], word[1], word[2], word[3]]))
        .collect()
}
516
517// =============================================================================
518// Tests
519// =============================================================================
520
#[cfg(test)]
mod tests {
    use super::*;

    /// Loads a real model file and checks a few expected tensors.
    ///
    /// The path comes from `RWKV_MODEL`, falling back to `rwkv-10m.safetensors`;
    /// the test returns early (passing) when that file does not exist.
    #[test]
    fn test_load_weights() {
        let path = match std::env::var("RWKV_MODEL") {
            Ok(from_env) => from_env,
            Err(_) => String::from("rwkv-10m.safetensors"),
        };

        if !Path::new(&path).exists() {
            eprintln!("Skipping test: model file not found at {}", path);
            return;
        }

        let weights = Weights::load(&path).unwrap();

        // Both the embedding table and the output head must be present.
        for required in &["model.embeddings.weight", "lm_head.weight"] {
            assert!(weights.get(required).is_some());
        }

        // The embedding table is expected to be 256 (vocab) x 256 (hidden).
        let emb = weights.get("model.embeddings.weight").unwrap();
        assert_eq!(emb.shape(), &[256usize, 256][..]);
    }

    /// Checks little-endian decoding of one and of two FP32 values.
    #[test]
    fn test_bytes_to_f32() {
        // 1.0 encoded as little-endian bytes.
        let one = [0x00, 0x00, 0x80, 0x3F];
        let decoded = bytes_to_f32(&one);
        assert_eq!(decoded.len(), 1);
        assert!((decoded[0] - 1.0).abs() < 1e-6);

        // 1.0 followed by 2.0.
        let one_two = [
            0x00, 0x00, 0x80, 0x3F, // 1.0
            0x00, 0x00, 0x00, 0x40, // 2.0
        ];
        let decoded = bytes_to_f32(&one_two);
        assert_eq!(decoded.len(), 2);
        for (value, expected) in decoded.iter().zip([1.0f32, 2.0].iter()) {
            assert!((value - expected).abs() < 1e-6);
        }
    }
}