1use ahash::{HashMap, HashMapExt};
17use std::fs::File;
18use std::io::{Read, Seek, SeekFrom};
19use std::path::Path;
20
21use anyhow::{bail, Context, Result};
22
23use super::tensor::{Tensor1D, Tensor2D};
24
/// Header entry for a single tensor: its dimensions and the byte range of its payload.
#[derive(Clone, Debug)]
struct TensorMeta {
    /// Dimension sizes taken from the entry's `shape` array.
    shape: Vec<usize>,
    /// First value of the entry's `data_offsets` array (where the payload begins).
    offset_start: usize,
    /// Second value of the entry's `data_offsets` array (where the payload ends).
    offset_end: usize,
}
39
/// A tensor held in memory as a flat `f32` buffer together with its dimensions.
#[derive(Debug, Clone)]
pub struct WeightTensor {
    /// Flat element storage.
    data: Vec<f32>,
    /// Dimension sizes recorded for this tensor.
    shape: Vec<usize>,
}

impl WeightTensor {
    /// Returns all elements as one flat slice.
    #[inline]
    pub fn data(&self) -> &[f32] {
        &self.data
    }

    /// Returns the recorded dimension sizes.
    #[inline]
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Returns the number of stored elements (the length of the flat buffer).
    #[inline]
    pub fn numel(&self) -> usize {
        self.data.len()
    }

    /// Returns the elements as a flat slice; equivalent to [`WeightTensor::data`].
    #[inline]
    pub fn as_1d(&self) -> &[f32] {
        &self.data
    }

    /// Iterates over `rows` consecutive slices of `cols` elements each.
    ///
    /// Row `r` is `data[r * cols..(r + 1) * cols]`.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `rows * cols` differs from the element count.
    /// In any build, advancing the iterator panics if a requested row extends
    /// past the end of the buffer.
    pub fn as_2d(&self, rows: usize, cols: usize) -> impl Iterator<Item = &[f32]> {
        debug_assert_eq!(rows * cols, self.data.len());
        (0..rows).map(move |r| &self.data[r * cols..(r + 1) * cols])
    }
}
83
84pub struct Weights {
90 pub(crate) tensors: HashMap<String, WeightTensor>,
91}
92
93impl Weights {
94 pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
101 let path = path.as_ref();
102 let mut file = File::open(path)
103 .with_context(|| format!("Failed to open weights file: {}", path.display()))?;
104
105 let mut header_len_bytes = [0u8; 8];
107 file.read_exact(&mut header_len_bytes)?;
108 let header_len = u64::from_le_bytes(header_len_bytes) as usize;
109
110 let mut header_bytes = vec![0u8; header_len];
112 file.read_exact(&mut header_bytes)?;
113 let header_str =
114 std::str::from_utf8(&header_bytes).context("Invalid UTF-8 in safetensors header")?;
115
116 let metas = parse_safetensors_header(header_str)?;
118
119 let data_offset = 8 + header_len;
121
122 let mut tensors = HashMap::with_capacity(metas.len());
124 for (name, meta) in metas {
125 file.seek(SeekFrom::Start((data_offset + meta.offset_start) as u64))?;
127
128 let byte_len = meta.offset_end - meta.offset_start;
129 let mut raw_bytes = vec![0u8; byte_len];
130 file.read_exact(&mut raw_bytes)?;
131
132 let data = bytes_to_f32(&raw_bytes);
134
135 tensors.insert(
136 name,
137 WeightTensor {
138 data,
139 shape: meta.shape,
140 },
141 );
142 }
143
144 Ok(Self { tensors })
145 }
146
147 #[inline]
149 pub fn get(&self, name: &str) -> Option<&WeightTensor> {
150 self.tensors.get(name)
151 }
152
153 pub fn require(&self, name: &str) -> Result<&WeightTensor> {
155 self.tensors
156 .get(name)
157 .with_context(|| format!("Missing required tensor: {}", name))
158 }
159
160 pub fn get_1d(&self, name: &str) -> Result<Tensor1D> {
162 let t = self.require(name)?;
163 Ok(Tensor1D::from_vec(t.data.clone()))
164 }
165
166 pub fn get_2d(&self, name: &str) -> Result<Tensor2D> {
168 let t = self.require(name)?;
169 match t.shape.len() {
170 1 => Ok(Tensor2D::from_vec(t.data.clone(), 1, t.shape[0])),
171 2 => Ok(Tensor2D::from_vec(t.data.clone(), t.shape[0], t.shape[1])),
172 _ => bail!(
173 "Expected 1D or 2D tensor for '{}', got shape {:?}",
174 name,
175 t.shape
176 ),
177 }
178 }
179
180 pub fn tensor_names(&self) -> impl Iterator<Item = &str> {
182 self.tensors.keys().map(|s| s.as_str())
183 }
184
185 pub fn print_summary(&self) {
187 let mut names: Vec<_> = self.tensors.keys().collect();
188 names.sort();
189 for name in names {
190 let t = &self.tensors[name];
191 println!(" {} {:?} ({} params)", name, t.shape, t.numel());
192 }
193 }
194}
195
196fn parse_safetensors_header(json: &str) -> Result<HashMap<String, TensorMeta>> {
208 let bytes = json.as_bytes();
209 let mut pos = 0;
210 let mut metas = HashMap::new();
211
212 skip_whitespace(bytes, &mut pos);
214 expect_char(bytes, &mut pos, b'{')?;
215
216 loop {
217 skip_whitespace(bytes, &mut pos);
218
219 if pos < bytes.len() && bytes[pos] == b'}' {
221 break;
222 }
223
224 if pos < bytes.len() && bytes[pos] == b',' {
226 pos += 1;
227 skip_whitespace(bytes, &mut pos);
228 }
229
230 let name = parse_string(bytes, &mut pos)?;
232
233 if name == "__metadata__" {
235 skip_whitespace(bytes, &mut pos);
236 expect_char(bytes, &mut pos, b':')?;
237 skip_json_value(bytes, &mut pos)?;
238 continue;
239 }
240
241 skip_whitespace(bytes, &mut pos);
242 expect_char(bytes, &mut pos, b':')?;
243 skip_whitespace(bytes, &mut pos);
244
245 let meta = parse_tensor_info(bytes, &mut pos)?;
247 metas.insert(name, meta);
248 }
249
250 Ok(metas)
251}
252
253fn parse_tensor_info(bytes: &[u8], pos: &mut usize) -> Result<TensorMeta> {
255 expect_char(bytes, pos, b'{')?;
256
257 let mut shape: Option<Vec<usize>> = None;
258 let mut offset_start: Option<usize> = None;
259 let mut offset_end: Option<usize> = None;
260
261 loop {
262 skip_whitespace(bytes, pos);
263
264 if *pos < bytes.len() && bytes[*pos] == b'}' {
265 *pos += 1;
266 break;
267 }
268
269 if *pos < bytes.len() && bytes[*pos] == b',' {
270 *pos += 1;
271 skip_whitespace(bytes, pos);
272 }
273
274 let key = parse_string(bytes, pos)?;
275 skip_whitespace(bytes, pos);
276 expect_char(bytes, pos, b':')?;
277 skip_whitespace(bytes, pos);
278
279 match key.as_str() {
280 "shape" => {
281 shape = Some(parse_int_array(bytes, pos)?);
282 }
283 "data_offsets" => {
284 let offsets = parse_int_array(bytes, pos)?;
285 if offsets.len() >= 2 {
286 offset_start = Some(offsets[0]);
287 offset_end = Some(offsets[1]);
288 }
289 }
290 _ => {
291 skip_json_value(bytes, pos)?;
293 }
294 }
295 }
296
297 Ok(TensorMeta {
298 shape: shape.unwrap_or_default(),
299 offset_start: offset_start.unwrap_or(0),
300 offset_end: offset_end.unwrap_or(0),
301 })
302}
303
304fn parse_string(bytes: &[u8], pos: &mut usize) -> Result<String> {
306 expect_char(bytes, pos, b'"')?;
307
308 let start = *pos;
309 while *pos < bytes.len() && bytes[*pos] != b'"' {
310 if bytes[*pos] == b'\\' {
311 *pos += 1; }
313 *pos += 1;
314 }
315 let end = *pos;
316
317 expect_char(bytes, pos, b'"')?;
318
319 String::from_utf8(bytes[start..end].to_vec()).context("Invalid UTF-8 in JSON string")
320}
321
322fn parse_int_array(bytes: &[u8], pos: &mut usize) -> Result<Vec<usize>> {
324 expect_char(bytes, pos, b'[')?;
325
326 let mut result = Vec::new();
327
328 loop {
329 skip_whitespace(bytes, pos);
330
331 if *pos < bytes.len() && bytes[*pos] == b']' {
332 *pos += 1;
333 break;
334 }
335
336 if *pos < bytes.len() && bytes[*pos] == b',' {
337 *pos += 1;
338 skip_whitespace(bytes, pos);
339 }
340
341 result.push(parse_int(bytes, pos)?);
342 }
343
344 Ok(result)
345}
346
347fn parse_int(bytes: &[u8], pos: &mut usize) -> Result<usize> {
349 let start = *pos;
350 while *pos < bytes.len() && bytes[*pos].is_ascii_digit() {
351 *pos += 1;
352 }
353
354 if start == *pos {
355 bail!("Expected integer at position {}", *pos);
356 }
357
358 let s = std::str::from_utf8(&bytes[start..*pos])?;
359 s.parse().context("Failed to parse integer")
360}
361
362fn skip_json_value(bytes: &[u8], pos: &mut usize) -> Result<()> {
364 skip_whitespace(bytes, pos);
365
366 if *pos >= bytes.len() {
367 return Ok(());
368 }
369
370 match bytes[*pos] {
371 b'"' => {
372 *pos += 1;
374 while *pos < bytes.len() && bytes[*pos] != b'"' {
375 if bytes[*pos] == b'\\' {
376 *pos += 1;
377 }
378 *pos += 1;
379 }
380 *pos += 1; }
382 b'{' => {
383 let mut depth = 1;
385 *pos += 1;
386 while *pos < bytes.len() && depth > 0 {
387 match bytes[*pos] {
388 b'{' => depth += 1,
389 b'}' => depth -= 1,
390 b'"' => {
391 *pos += 1;
392 while *pos < bytes.len() && bytes[*pos] != b'"' {
393 if bytes[*pos] == b'\\' {
394 *pos += 1;
395 }
396 *pos += 1;
397 }
398 }
399 _ => {}
400 }
401 *pos += 1;
402 }
403 }
404 b'[' => {
405 let mut depth = 1;
407 *pos += 1;
408 while *pos < bytes.len() && depth > 0 {
409 match bytes[*pos] {
410 b'[' => depth += 1,
411 b']' => depth -= 1,
412 b'"' => {
413 *pos += 1;
414 while *pos < bytes.len() && bytes[*pos] != b'"' {
415 if bytes[*pos] == b'\\' {
416 *pos += 1;
417 }
418 *pos += 1;
419 }
420 }
421 _ => {}
422 }
423 *pos += 1;
424 }
425 }
426 _ => {
427 while *pos < bytes.len() && !matches!(bytes[*pos], b',' | b'}' | b']') {
429 *pos += 1;
430 }
431 }
432 }
433
434 Ok(())
435}
436
/// Advances `*pos` past any run of ASCII whitespace starting at `*pos`.
///
/// Does nothing if `*pos` is at or beyond the end of `bytes`.
#[inline]
fn skip_whitespace(bytes: &[u8], pos: &mut usize) {
    let remaining = bytes.get(*pos..).unwrap_or(&[]);
    *pos += remaining
        .iter()
        .take_while(|b| b.is_ascii_whitespace())
        .count();
}
444
445#[inline]
447fn expect_char(bytes: &[u8], pos: &mut usize, expected: u8) -> Result<()> {
448 if *pos >= bytes.len() || bytes[*pos] != expected {
449 bail!(
450 "Expected '{}' at position {}, found '{}'",
451 expected as char,
452 *pos,
453 bytes.get(*pos).map(|&b| b as char).unwrap_or('\0')
454 );
455 }
456 *pos += 1;
457 Ok(())
458}
459
/// Decodes a byte buffer as consecutive little-endian `f32` values.
///
/// Every complete group of four bytes yields one value; up to three trailing
/// bytes that do not form a complete group are ignored.
#[inline]
fn bytes_to_f32(bytes: &[u8]) -> Vec<f32> {
    // `chunks_exact` guarantees four-byte chunks, so the indexing below cannot fail.
    bytes
        .chunks_exact(4)
        .map(|word| f32::from_le_bytes([word[0], word[1], word[2], word[3]]))
        .collect()
}
516
#[cfg(test)]
mod tests {
    use super::*;

    /// Loads a model file (path from `RWKV_MODEL`, with a default file name as
    /// fallback) and checks the expected tensors; skipped if the file is absent.
    #[test]
    fn test_load_weights() {
        let model_path = match std::env::var("RWKV_MODEL") {
            Ok(value) => value,
            Err(_) => "rwkv-10m.safetensors".to_string(),
        };

        if !Path::new(&model_path).exists() {
            eprintln!("Skipping test: model file not found at {}", model_path);
            return;
        }

        let weights = Weights::load(&model_path).unwrap();

        for required in &["model.embeddings.weight", "lm_head.weight"] {
            assert!(weights.get(required).is_some());
        }

        let embeddings = weights.get("model.embeddings.weight").unwrap();
        assert_eq!(embeddings.shape(), &[256usize, 256][..]);
    }

    /// Checks decoding of one and two little-endian `f32` values.
    #[test]
    fn test_bytes_to_f32() {
        let approx_eq = |actual: f32, expected: f32| (actual - expected).abs() < 1e-6;

        // 1.0 encoded as little-endian bytes.
        let one = [0x00, 0x00, 0x80, 0x3F];
        let decoded = bytes_to_f32(&one);
        assert_eq!(decoded.len(), 1);
        assert!(approx_eq(decoded[0], 1.0));

        // 1.0 followed by 2.0.
        let one_two = [0x00, 0x00, 0x80, 0x3F, 0x00, 0x00, 0x00, 0x40];
        let decoded = bytes_to_f32(&one_two);
        assert_eq!(decoded.len(), 2);
        assert!(approx_eq(decoded[0], 1.0));
        assert!(approx_eq(decoded[1], 2.0));
    }
}