1use anyhow::{Context, Result, bail};
15use serde_json::json;
16use std::fs;
17use std::io::{Cursor, Read, Write};
18use std::path::{Path, PathBuf};
19use std::sync::Arc;
20
21use crate::backends::llm_policy::{
22 self, LlmPolicy, OptimizerKind, PolicyAction, PolicyRuntime, split_method_policy_segments,
23};
24pub mod rwkv7;
26pub use crate::coders;
28pub use crate::coders::CoderType;
30
31use crate::coders::{
32 ANS_TOTAL, ArithmeticDecoder, ArithmeticEncoder, BlockedRansDecoder, BlockedRansEncoder,
33 CDF_TOTAL, Cdf, quantize_pdf_to_cdf_with_buffer, quantize_pdf_to_rans_cdf_with_buffer,
34};
35
36pub use rwkv7::Config;
38pub use rwkv7::Model;
40pub use rwkv7::ScratchBuffers;
42pub use rwkv7::State;
44
/// Container magic number; ASCII "ZPTG" (written little-endian, so the
/// file begins with the bytes "GTPZ").
pub const MAGIC: u32 = 0x5a505447;

/// Current container format version; `Header::read` rejects anything newer.
pub const VERSION: u8 = 2;

/// Byte-level model: one symbol per byte value.
pub const VOCAB_SIZE: usize = 256;

/// TBPTT window substituted when a policy trains non-head parameters but
/// asked for a window <= 1 (see `Compressor::effective_full_bptt`).
const DEFAULT_FULL_TBPTT_WINDOW: usize = 8;

/// Chunk length handed to `online_train_segment_tbptt` when replaying a
/// recorded segment.
const TBPTT_REPLAY_CHUNK: usize = 32;
/// Location of the optimizer-state sidecar that accompanies a model file:
/// the model's extension is replaced with `opt.safetensors`.
fn optimizer_sidecar_path(model_path: &Path) -> PathBuf {
    let mut sidecar = model_path.to_path_buf();
    sidecar.set_extension("opt.safetensors");
    sidecar
}
/// Scope names accepted in a policy `train` action for this backend;
/// handed to `llm_policy::parse_policy_segment` for validation.
const RWKV_TRAIN_SCOPES: &[&str] = &[
    "embed",
    "pre_norm",
    "attn_norm",
    "ffn_norm",
    "attn",
    "ffn",
    "head",
    "bias",
    "all",
    "none",
];
77
/// A `Write` sink that discards all data and only counts the bytes "written".
struct CountingWriter {
    /// Total bytes accepted so far (saturating at `u64::MAX`).
    n: u64,
}
81
82impl CountingWriter {
83 #[inline]
84 fn new() -> Self {
85 Self { n: 0 }
86 }
87
88 #[inline]
89 fn bytes_written(&self) -> u64 {
90 self.n
91 }
92}
93
94impl Write for CountingWriter {
95 #[inline]
96 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
97 let n = buf.len();
98 self.n = self.n.saturating_add(n as u64);
99 Ok(n)
100 }
101
102 #[inline]
103 fn flush(&mut self) -> std::io::Result<()> {
104 Ok(())
105 }
106}
107
/// How (and whether) parameters are updated while streaming.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum OnlineTrainMode {
    /// No online updates.
    None,
    /// Plain stochastic gradient descent.
    Sgd,
    /// Adam optimizer.
    Adam,
}
118
/// Configuration for an online (randomly initialized) RWKV model.
#[derive(Clone, Debug)]
pub struct OnlineConfig {
    /// Hidden size; clamped to >= 64 and must be a multiple of 64
    /// (see `to_rwkv_config`).
    pub hidden: usize,
    /// Number of layers (clamped to >= 1).
    pub layers: usize,
    /// FFN intermediate size (clamped to >= 1).
    pub intermediate: usize,
    /// Low-rank size of the decay projection.
    pub decay_rank: usize,
    /// Low-rank size of the `a` projection.
    pub a_rank: usize,
    /// Low-rank size of the `v` projection.
    pub v_rank: usize,
    /// Low-rank size of the gate projection.
    pub g_rank: usize,
    /// Seed for random weight initialization.
    pub seed: u64,
    /// Default training mode when no policy overrides it.
    pub train_mode: OnlineTrainMode,
    /// Learning rate for online updates.
    pub lr: f32,
    /// Stride between updates; normalized to >= 1 where used.
    pub stride: usize,
}
145
impl Default for OnlineConfig {
    /// Small default model (256 hidden, 6 layers) with training disabled.
    fn default() -> Self {
        Self {
            hidden: 256,
            layers: 6,
            intermediate: 1024,
            decay_rank: 32,
            a_rank: 32,
            v_rank: 32,
            g_rank: 64,
            seed: 0,
            train_mode: OnlineTrainMode::None,
            lr: 0.001,
            stride: 1,
        }
    }
}
163
164impl OnlineConfig {
165 pub fn to_rwkv_config(&self) -> Result<Config> {
167 let hidden = self.hidden.max(64);
168 if !hidden.is_multiple_of(64) {
169 bail!("rwkv hidden must be a multiple of 64 (got {hidden})");
170 }
171 let num_heads = hidden / 64;
172 let cfg = Config {
173 vocab_size: 256,
174 hidden_size: hidden,
175 num_layers: self.layers.max(1),
176 num_heads,
177 head_dim: 64,
178 intermediate_size: self.intermediate.max(1),
179 layer_norm_eps: 1e-5,
180 group_norm_eps: 64e-5,
181 decay_low_rank: self.decay_rank.max(1),
182 a_low_rank: self.a_rank.max(1),
183 v_low_rank: self.v_rank.max(1),
184 g_low_rank: self.g_rank.max(1),
185 };
186 cfg.validate()?;
187 Ok(cfg)
188 }
189}
190
/// Parsed form of an rwkv method string (see `parse_method_spec`).
#[derive(Clone, Debug)]
pub enum MethodSpec {
    /// Load a pretrained model from disk (`file:<path>` or a bare path).
    File {
        /// Path to the model file.
        path: PathBuf,
        /// Optional training/policy schedule.
        policy: Option<LlmPolicy>,
    },
    /// Build a model from an inline configuration (`cfg:...` or positional CSV).
    Online {
        /// Model shape / training configuration.
        cfg: OnlineConfig,
        /// Optional training/policy schedule.
        policy: Option<LlmPolicy>,
    },
}
209
/// Mutable bookkeeping for online adaptation of a `Compressor`.
#[derive(Clone)]
struct OnlineRuntime {
    /// Configuration the runtime was created from.
    cfg: OnlineConfig,
    /// Canonical method string (used to reproduce this runtime).
    canonical_method: String,
    /// Optional training policy.
    policy: Option<LlmPolicy>,
    /// Compiled per-stream policy executor; rebuilt by `prepare_policy_stream`.
    policy_runtime: Option<PolicyRuntime>,
    /// True when the policy trains non-head parameters and therefore needs
    /// a recorded forward trace (see `policy_needs_full_trace`).
    needs_full_trace: bool,
    /// Total symbols in the current stream, when known.
    policy_stream_total: Option<u64>,
    /// Train actions executed in the current stream.
    policy_train_steps: u64,
    /// Total symbols consumed across the runtime's lifetime.
    tokens_processed: u64,
    /// Additive output bias applied on top of the lm-head logits.
    out_bias: Vec<f32>,
    /// Adam first moment for `out_bias` (allocated only when Adam is possible).
    adam_m: Option<Vec<f32>>,
    /// Adam second moment for `out_bias`.
    adam_v: Option<Vec<f32>>,
    /// Adam state for full-parameter training; allocated lazily on first use.
    full_adam: Option<rwkv7::FullAdamState>,
    /// Adam first moment for the lm-head weight matrix.
    lm_head_adam_m: Option<Vec<f32>>,
    /// Adam second moment for the lm-head weight matrix.
    lm_head_adam_v: Option<Vec<f32>>,
    /// Adam step counter (shared across bias/head/full updates).
    adam_t: usize,
    /// TBPTT recorder; present only when `needs_full_trace`.
    full_tbptt: Option<FullTbpttRuntime>,
}
229
/// Hyper-parameters of the currently buffered full-parameter TBPTT segment.
#[derive(Clone, Copy, Debug)]
struct FullTrainSettings {
    /// Optimizer applied when the segment is flushed.
    optimizer: OptimizerKind,
    /// Learning rate.
    lr: f32,
    /// Parameter groups being trained.
    scope: rwkv7::TrainScopeMask,
    /// Requested truncated-BPTT window length.
    bptt: usize,
    /// Gradient clip threshold (values <= 0 disable clipping).
    clip: f32,
}
238
239impl FullTrainSettings {
240 fn matches(
241 self,
242 optimizer: OptimizerKind,
243 lr: f32,
244 scope: rwkv7::TrainScopeMask,
245 bptt: usize,
246 clip: f32,
247 ) -> bool {
248 self.optimizer == optimizer
249 && self.lr.to_bits() == lr.to_bits()
250 && self.scope == scope
251 && self.bptt == bptt
252 && self.clip.to_bits() == clip.to_bits()
253 }
254}
255
/// Recorder for full-parameter TBPTT: buffers (token, target) steps plus the
/// recurrent state at segment start so the segment can be replayed later.
#[derive(Clone)]
struct FullTbpttRuntime {
    /// Token fed to the most recent forward pass, not yet enqueued as a step.
    pending_input_token: Option<u32>,
    /// Recurrent state captured just before that forward pass.
    pending_input_pre_state: Option<State>,
    /// State at the start of the buffered segment.
    segment_start_state: Option<State>,
    /// Buffered (input token, target symbol) pairs.
    steps: Vec<(u32, u8)>,
    /// Settings shared by all buffered steps; a change forces a flush.
    settings: Option<FullTrainSettings>,
}
264
/// Point-in-time copy of a compressor's mutable runtime, saved/restored via
/// `Compressor::snapshot_runtime` and `Compressor::restore_runtime`.
#[derive(Clone)]
pub struct RuntimeSnapshot {
    model: Arc<Model>,
    scratch: ScratchBuffers,
    state: State,
    pdf_buffer: Vec<f64>,
    online: Option<OnlineRuntime>,
}
274
impl OnlineRuntime {
    /// Builds the online-training runtime.
    ///
    /// Adam moment buffers are allocated up front when either the cfg's
    /// train mode or the policy can request Adam; the TBPTT recorder is
    /// allocated only when the policy trains trace-requiring parameters.
    fn new(
        cfg: OnlineConfig,
        canonical_method: String,
        policy: Option<LlmPolicy>,
        vocab_size: usize,
        hidden_size: usize,
    ) -> Self {
        let mut use_adam = matches!(cfg.train_mode, OnlineTrainMode::Adam);
        if let Some(pol) = &policy {
            use_adam = policy_uses_adam(pol) || use_adam;
        }
        let needs_full_trace = policy
            .as_ref()
            .map(policy_needs_full_trace)
            .unwrap_or(false);
        Self {
            canonical_method,
            cfg,
            policy,
            policy_runtime: None,
            needs_full_trace,
            policy_stream_total: None,
            policy_train_steps: 0,
            tokens_processed: 0,
            out_bias: vec![0.0; vocab_size],
            // First/second Adam moments for the output bias.
            adam_m: use_adam.then(|| vec![0.0; vocab_size]),
            adam_v: use_adam.then(|| vec![0.0; vocab_size]),
            full_adam: None,
            // First/second Adam moments for the lm-head weight matrix.
            lm_head_adam_m: use_adam.then(|| vec![0.0; vocab_size * hidden_size]),
            lm_head_adam_v: use_adam.then(|| vec![0.0; vocab_size * hidden_size]),
            adam_t: 0,
            full_tbptt: needs_full_trace.then(|| FullTbpttRuntime {
                pending_input_token: None,
                pending_input_pre_state: None,
                segment_start_state: None,
                steps: Vec::new(),
                settings: None,
            }),
        }
    }

    /// Resets per-stream bookkeeping and compiles the policy schedule for a
    /// stream of `total_symbols` symbols (`None` when the length is unknown).
    fn prepare_policy_stream(&mut self, total_symbols: Option<u64>) -> Result<()> {
        self.policy_stream_total = total_symbols;
        self.policy_train_steps = 0;
        if let Some(tbptt) = self.full_tbptt.as_mut() {
            // Drop any partially recorded segment from the previous stream.
            tbptt.segment_start_state = None;
            tbptt.steps.clear();
            tbptt.settings = None;
        }
        self.policy_runtime = match &self.policy {
            Some(p) => Some(PolicyRuntime::new(p.compile(total_symbols)?)),
            None => None,
        };
        Ok(())
    }

    /// Policy action for the next symbol; lazily compiles the schedule
    /// (with unknown stream length) if a stream was never prepared.
    #[inline]
    fn next_policy_action(&mut self) -> Result<Option<PolicyAction>> {
        if self.policy.is_none() {
            return Ok(None);
        }
        if self.policy_runtime.is_none() {
            self.prepare_policy_stream(None)?;
        }
        Ok(self.policy_runtime.as_mut().map(PolicyRuntime::next_action))
    }
}
345
/// Applies one online update to the lm-head weights and/or output bias from
/// a single observed symbol.
///
/// `grad = target - p` is the negative cross-entropy gradient w.r.t. each
/// logit, so parameters move by `+= lr * step`. Rows beyond
/// `min(bias, pdf, head_rows)` are left untouched. `clip > 0` clamps the
/// per-logit gradient; Adam moment buffers are (re)allocated on demand and
/// share the `adam_t` bias-correction counter.
#[allow(clippy::needless_range_loop, clippy::too_many_arguments)]
fn apply_online_lm_head_update(
    model: &mut Model,
    online: &mut OnlineRuntime,
    hidden: &[f32],
    symbol: u8,
    pdf: &[f64],
    lr: f32,
    optimizer: OptimizerKind,
    train_head: bool,
    train_bias: bool,
    clip: f32,
) {
    let h = hidden.len();
    if h == 0 {
        return;
    }

    let head = model.lm_head_weights_mut();
    let vocab_rows = head.len() / h;
    // Guard against mismatched buffer sizes: only update rows all three agree on.
    let n = online.out_bias.len().min(pdf.len()).min(vocab_rows);

    match optimizer {
        OptimizerKind::Sgd => {
            for (i, p_raw) in pdf.iter().enumerate().take(n) {
                // Clamp keeps p in (0, 1] so grad stays finite.
                let p = (*p_raw).clamp(1e-12, 1.0) as f32;
                let target = if i == symbol as usize { 1.0 } else { 0.0 };
                let mut grad = target - p;
                if clip > 0.0 {
                    grad = grad.clamp(-clip, clip);
                }
                if train_bias {
                    online.out_bias[i] += lr * grad;
                }

                if train_head {
                    // Outer-product update: row i moves along the hidden vector.
                    let row_off = i * h;
                    for j in 0..h {
                        head[row_off + j] += lr * grad * hidden[j];
                    }
                }
            }
        }
        OptimizerKind::Adam => {
            online.adam_t = online.adam_t.saturating_add(1);
            let t = online.adam_t as i32;
            let b1 = 0.9f32;
            let b2 = 0.999f32;
            let eps = 1e-8f32;
            let bias_corr1 = 1.0 - b1.powi(t);
            let bias_corr2 = 1.0 - b2.powi(t);
            // Lazily (re)allocate moment buffers if this runtime was built
            // without Adam enabled.
            if online.adam_m.is_none() || online.adam_v.is_none() {
                online.adam_m = Some(vec![0.0; online.out_bias.len()]);
                online.adam_v = Some(vec![0.0; online.out_bias.len()]);
            }
            if online.lm_head_adam_m.is_none() || online.lm_head_adam_v.is_none() {
                online.lm_head_adam_m = Some(vec![0.0; vocab_rows * h]);
                online.lm_head_adam_v = Some(vec![0.0; vocab_rows * h]);
            }
            let bm = online.adam_m.as_mut().expect("adam_m initialized");
            let bv = online.adam_v.as_mut().expect("adam_v initialized");
            let hm = online
                .lm_head_adam_m
                .as_mut()
                .expect("lm_head_adam_m initialized");
            let hv = online
                .lm_head_adam_v
                .as_mut()
                .expect("lm_head_adam_v initialized");
            for i in 0..n {
                let p = pdf[i].clamp(1e-12, 1.0) as f32;
                let target = if i == symbol as usize { 1.0 } else { 0.0 };
                let mut grad = target - p;
                if clip > 0.0 {
                    grad = grad.clamp(-clip, clip);
                }

                if train_bias {
                    // Standard Adam with bias-corrected moments.
                    bm[i] = b1 * bm[i] + (1.0 - b1) * grad;
                    bv[i] = b2 * bv[i] + (1.0 - b2) * grad * grad;
                    let m_hat = bm[i] / bias_corr1;
                    let v_hat = bv[i] / bias_corr2;
                    online.out_bias[i] += lr * m_hat / (v_hat.sqrt() + eps);
                }

                if train_head {
                    let row_off = i * h;
                    for j in 0..h {
                        let idx = row_off + j;
                        let g = grad * hidden[j];
                        hm[idx] = b1 * hm[idx] + (1.0 - b1) * g;
                        hv[idx] = b2 * hv[idx] + (1.0 - b2) * g * g;
                        let m_hat_w = hm[idx] / bias_corr1;
                        let v_hat_w = hv[idx] / bias_corr2;
                        head[idx] += lr * m_hat_w / (v_hat_w.sqrt() + eps);
                    }
                }
            }
        }
    }
}
447
448fn policy_uses_adam(policy: &LlmPolicy) -> bool {
449 use llm_policy::ScheduleRule;
450 for rule in &policy.schedule {
451 match rule {
452 ScheduleRule::Interval(interval) => {
453 if let PolicyAction::Train(train) = &interval.action
454 && matches!(train.optimizer, OptimizerKind::Adam)
455 {
456 return true;
457 }
458 }
459 ScheduleRule::Repeat(repeat) => {
460 for seg in &repeat.pattern {
461 if let PolicyAction::Train(train) = &seg.action
462 && matches!(train.optimizer, OptimizerKind::Adam)
463 {
464 return true;
465 }
466 }
467 }
468 }
469 }
470 false
471}
472
473fn scope_needs_full_trace(scope: &llm_policy::TrainScopeSet) -> bool {
474 scope.all
475 || scope.contains("embed")
476 || scope.contains("pre_norm")
477 || scope.contains("attn_norm")
478 || scope.contains("ffn_norm")
479 || scope.contains("attn")
480 || scope.contains("ffn")
481}
482
483fn policy_needs_full_trace(policy: &LlmPolicy) -> bool {
484 use llm_policy::ScheduleRule;
485 for rule in &policy.schedule {
486 match rule {
487 ScheduleRule::Interval(interval) => {
488 if let PolicyAction::Train(train) = &interval.action
489 && scope_needs_full_trace(&train.scope)
490 {
491 return true;
492 }
493 }
494 ScheduleRule::Repeat(repeat) => {
495 for seg in &repeat.pattern {
496 if let PolicyAction::Train(train) = &seg.action
497 && scope_needs_full_trace(&train.scope)
498 {
499 return true;
500 }
501 }
502 }
503 }
504 }
505 false
506}
507
508fn scope_from_train_action(train: &llm_policy::TrainAction) -> rwkv7::TrainScopeMask {
509 if train.scope.all {
510 return rwkv7::TrainScopeMask::all();
511 }
512 rwkv7::TrainScopeMask {
513 embed: train.scope.contains("embed"),
514 pre_norm: train.scope.contains("pre_norm"),
515 attn_norm: train.scope.contains("attn_norm"),
516 ffn_norm: train.scope.contains("ffn_norm"),
517 attn: train.scope.contains("attn"),
518 ffn: train.scope.contains("ffn"),
519 head: train.scope.contains("head"),
520 bias: train.scope.contains("bias"),
521 }
522}
523
524fn cfg_to_method_string(cfg: &OnlineConfig) -> String {
525 let train = match cfg.train_mode {
526 OnlineTrainMode::None => "none",
527 OnlineTrainMode::Sgd => "sgd",
528 OnlineTrainMode::Adam => "adam",
529 };
530 format!(
531 "cfg:hidden={},layers={},intermediate={},decay_rank={},a_rank={},v_rank={},g_rank={},seed={},train={},lr={},stride={}",
532 cfg.hidden,
533 cfg.layers,
534 cfg.intermediate,
535 cfg.decay_rank,
536 cfg.a_rank,
537 cfg.v_rank,
538 cfg.g_rank,
539 cfg.seed,
540 train,
541 cfg.lr,
542 cfg.stride.max(1),
543 )
544}
545
/// Numerically stable softmax over `logits` (plus an optional additive
/// `bias`), written into `pdf_out`, with every probability floored at 1e-12
/// and the result renormalized to sum to 1.
///
/// Empty input leaves `pdf_out` untouched. A non-finite or non-positive sum
/// falls back to the uniform distribution.
fn softmax_pdf_floor_with_bias(logits: &[f32], bias: Option<&[f32]>, pdf_out: &mut [f64]) {
    debug_assert_eq!(logits.len(), pdf_out.len());
    if let Some(b) = bias {
        debug_assert_eq!(b.len(), logits.len());
    }
    let n = logits.len();
    if n == 0 {
        return;
    }

    // Biased logit accessor (f32 arithmetic, matching the stored weights).
    let z = |i: usize| -> f32 {
        match bias {
            Some(b) => logits[i] + b[i],
            None => logits[i],
        }
    };

    // Subtract the max logit before exponentiating for stability.
    let mut max_logit = f32::NEG_INFINITY;
    for i in 0..n {
        let v = z(i);
        if v > max_logit {
            max_logit = v;
        }
    }

    let mut sum = 0.0f64;
    for i in 0..n {
        let e = f64::from(z(i) - max_logit).exp();
        pdf_out[i] = e;
        sum += e;
    }

    let inv_sum = if sum.is_finite() && sum > 0.0 {
        1.0 / sum
    } else {
        1.0 / (n as f64)
    };

    // Floor, then renormalize so the floored mass still sums to 1.
    let floor = 1e-12f64;
    let mut norm = 0.0f64;
    for p in pdf_out.iter_mut() {
        *p = (*p * inv_sum).max(floor);
        norm += *p;
    }
    let inv_norm = if norm.is_finite() && norm > 0.0 {
        1.0 / norm
    } else {
        1.0 / (n as f64)
    };
    for p in pdf_out.iter_mut() {
        *p *= inv_norm;
    }
}
607
608fn parse_u64(v: &str, key: &str) -> Result<u64> {
609 v.parse::<u64>()
610 .with_context(|| format!("invalid integer value for '{key}': {v}"))
611}
612
613fn parse_usize(v: &str, key: &str) -> Result<usize> {
614 v.parse::<usize>()
615 .with_context(|| format!("invalid integer value for '{key}': {v}"))
616}
617
618fn parse_f32(v: &str, key: &str) -> Result<f32> {
619 v.parse::<f32>()
620 .with_context(|| format!("invalid float value for '{key}': {v}"))
621}
622
623fn parse_train_mode_token(v: &str) -> Result<OnlineTrainMode> {
624 let code = v.trim().to_ascii_lowercase();
625 match code.as_str() {
626 "0" | "none" | "off" => Ok(OnlineTrainMode::None),
627 "1" | "sgd" => Ok(OnlineTrainMode::Sgd),
628 "2" | "adam" => Ok(OnlineTrainMode::Adam),
629 other => bail!("unknown train mode '{other}'"),
630 }
631}
632
633fn parse_cfg_positional(csv: &str) -> Result<OnlineConfig> {
634 let vals: Vec<&str> = csv
635 .split(',')
636 .map(|s| s.trim())
637 .filter(|s| !s.is_empty())
638 .collect();
639 if vals.len() != 6 && vals.len() != 7 {
640 bail!(
641 "positional cfg format expects 6 or 7 values: hidden,intermediate,layers,train,seed,lr[,stride]"
642 );
643 }
644
645 let cfg = OnlineConfig {
646 hidden: parse_usize(vals[0], "hidden")?,
647 intermediate: parse_usize(vals[1], "intermediate")?,
648 layers: parse_usize(vals[2], "layers")?,
649 train_mode: parse_train_mode_token(vals[3])?,
650 seed: parse_u64(vals[4], "seed")?,
651 lr: parse_f32(vals[5], "lr")?,
652 stride: if vals.len() == 7 {
653 parse_usize(vals[6], "stride")?
654 } else {
655 1
656 },
657 ..OnlineConfig::default()
658 };
659 Ok(cfg)
660}
661
/// Parses an rwkv method string into a `MethodSpec`.
///
/// Accepted forms, tried in order:
/// 1. `file:<path>` — load a pretrained model;
/// 2. `cfg:<k=v,...>` or `cfg:<positional CSV>` — build an online model;
/// 3. a bare path that exists on disk — treated like `file:`;
/// 4. a bare comma-separated list — treated as positional cfg CSV.
///
/// An optional trailing `;policy:<...>` segment is split off first and
/// parsed against `RWKV_TRAIN_SCOPES`. Combining a policy `load_from`
/// with an explicit file path is rejected.
pub fn parse_method_spec(method: &str) -> Result<MethodSpec> {
    let (base, policy_segment) = split_method_policy_segments(method)?;
    let parse_policy = |s: &str| llm_policy::parse_policy_segment(s, RWKV_TRAIN_SCOPES);
    let policy = policy_segment
        .as_deref()
        .map(parse_policy)
        .transpose()
        .context("failed to parse rwkv policy segment")?;

    if let Some(path) = base.strip_prefix("file:") {
        let p = PathBuf::from(path.trim());
        if p.as_os_str().is_empty() {
            bail!("empty file path in rwkv method");
        }
        // load_from would conflict with the explicitly given model file.
        if policy.as_ref().and_then(|p| p.load_from.as_ref()).is_some() {
            bail!("rwkv method cannot use policy load_from together with file:<path>");
        }
        return Ok(MethodSpec::File { path: p, policy });
    }

    if let Some(cfg_s) = base.strip_prefix("cfg:") {
        // No '=' means the positional CSV shorthand.
        if !cfg_s.contains('=') {
            return Ok(MethodSpec::Online {
                cfg: parse_cfg_positional(cfg_s)?,
                policy,
            });
        }
        let mut cfg = OnlineConfig::default();
        for pair in cfg_s.split(',') {
            let pair = pair.trim();
            if pair.is_empty() {
                continue;
            }
            let (k, v) = pair
                .split_once('=')
                .with_context(|| format!("invalid cfg key/value pair '{pair}'"))?;
            let key = k.trim().to_ascii_lowercase();
            let val = v.trim();
            match key.as_str() {
                "hidden" => cfg.hidden = parse_usize(val, "hidden")?,
                "layers" => cfg.layers = parse_usize(val, "layers")?,
                "intermediate" => cfg.intermediate = parse_usize(val, "intermediate")?,
                "decay_rank" => cfg.decay_rank = parse_usize(val, "decay_rank")?,
                "a_rank" => cfg.a_rank = parse_usize(val, "a_rank")?,
                "v_rank" => cfg.v_rank = parse_usize(val, "v_rank")?,
                "g_rank" => cfg.g_rank = parse_usize(val, "g_rank")?,
                "seed" => cfg.seed = parse_u64(val, "seed")?,
                "lr" => cfg.lr = parse_f32(val, "lr")?,
                "stride" => cfg.stride = parse_usize(val, "stride")?,
                "train" | "train_mode" => cfg.train_mode = parse_train_mode_token(val)?,
                other => bail!("unknown rwkv cfg key '{other}'"),
            }
        }
        return Ok(MethodSpec::Online { cfg, policy });
    }

    // Bare path: only accepted when it actually exists on disk.
    let plain = PathBuf::from(base.trim());
    if plain.exists() {
        if policy.as_ref().and_then(|p| p.load_from.as_ref()).is_some() {
            bail!("rwkv method cannot use policy load_from together with file path");
        }
        return Ok(MethodSpec::File {
            path: plain,
            policy,
        });
    }

    // Last resort: bare comma-separated values as positional cfg.
    if base.contains(',') {
        return Ok(MethodSpec::Online {
            cfg: parse_cfg_positional(&base)?,
            policy,
        });
    }

    bail!(
        "rwkv method must be 'file:<path>', 'cfg:<k=v,...>', positional cfg CSV, or an existing model path"
    );
}
748
/// Fixed-size container header written ahead of the compressed payload.
#[derive(Debug, Clone)]
pub struct Header {
    /// Format magic; must equal `MAGIC`.
    pub magic: u32,
    /// Format version; readers reject versions newer than `VERSION`.
    pub version: u8,
    /// Entropy-coder tag: 0 = arithmetic coding, otherwise rANS.
    pub coder: u8,
    /// Length of the original (uncompressed) data in bytes.
    pub original_len: u64,
    /// CRC-32 checksum supplied by the writer (see the `crc32` helper).
    pub crc32: u32,
}
774
impl Header {
    /// Serialized size in bytes:
    /// magic(4) + version(1) + coder(1) + original_len(8) + crc32(4) = 18.
    pub const SIZE: usize = 4 + 1 + 1 + 8 + 4;

    /// Builds a header stamped with the current `MAGIC` and `VERSION`.
    pub fn new(coder: CoderType, original_len: u64, crc32: u32) -> Self {
        Self {
            magic: MAGIC,
            version: VERSION,
            coder: match coder {
                CoderType::AC => 0,
                CoderType::RANS => 1,
            },
            original_len,
            crc32,
        }
    }

    /// Writes the header: fields in declaration order, integers
    /// little-endian; exactly `SIZE` bytes.
    pub fn write<W: Write>(&self, w: &mut W) -> Result<()> {
        w.write_all(&self.magic.to_le_bytes())?;
        w.write_all(&[self.version])?;
        w.write_all(&[self.coder])?;
        w.write_all(&self.original_len.to_le_bytes())?;
        w.write_all(&self.crc32.to_le_bytes())?;
        Ok(())
    }

    /// Reads and validates a header.
    ///
    /// # Errors
    /// Fails on I/O errors, a wrong magic number, or a version newer than
    /// `VERSION` (older versions are accepted).
    pub fn read<R: Read>(r: &mut R) -> Result<Self> {
        let mut buf4 = [0u8; 4];
        let mut buf8 = [0u8; 8];
        let mut buf1 = [0u8; 1];

        r.read_exact(&mut buf4)?;
        let magic = u32::from_le_bytes(buf4);
        if magic != MAGIC {
            bail!(
                "Invalid magic number: expected 0x{:08X}, got 0x{:08X}",
                MAGIC,
                magic
            );
        }

        r.read_exact(&mut buf1)?;
        let version = buf1[0];
        if version > VERSION {
            bail!(
                "Unsupported version: {} (max supported: {})",
                version,
                VERSION
            );
        }

        r.read_exact(&mut buf1)?;
        let coder = buf1[0];

        r.read_exact(&mut buf8)?;
        let original_len = u64::from_le_bytes(buf8);

        r.read_exact(&mut buf4)?;
        let crc32 = u32::from_le_bytes(buf4);

        Ok(Self {
            magic,
            version,
            coder,
            original_len,
            crc32,
        })
    }

    /// Maps the stored coder tag back to a `CoderType`
    /// (0 = AC; any other value decodes as RANS).
    pub fn coder_type(&self) -> CoderType {
        match self.coder {
            0 => CoderType::AC,
            _ => CoderType::RANS,
        }
    }
}
855
/// CRC-32 checksum of `data`; thin re-export of the coders implementation.
pub fn crc32(data: &[u8]) -> u32 {
    crate::coders::crc32(data)
}
866
/// Streaming RWKV-based compressor: model + recurrent state + reusable
/// entropy-coder scratch buffers, optionally with an online-training runtime.
pub struct Compressor {
    /// Shared model weights (mutated copy-on-write via `Arc::make_mut`
    /// when online training updates them).
    pub model: Arc<Model>,
    /// Recurrent state, advanced one token per symbol.
    pub state: State,
    /// Forward-pass scratch; holds the latest logits.
    pub scratch: ScratchBuffers,
    /// Most recently computed per-symbol probability distribution.
    pub pdf_buffer: Vec<f64>,
    /// Reusable CDF buffer for the arithmetic coder.
    pub cdf_buffer_ac: Vec<u32>,
    /// Reusable frequency buffer for the arithmetic coder.
    pub ac_freq_buffer: Vec<i64>,
    /// Reusable CDF buffer for the rANS coder.
    pub cdf_buffer_rans: Vec<u32>,
    /// Reusable frequency buffer for the rANS coder.
    pub rans_freq_buffer: Vec<i64>,
    /// Online adaptation runtime (None for purely static models).
    online: Option<OnlineRuntime>,
    /// Path the model was loaded from; used for sidecar lookup.
    source_model_path: Option<PathBuf>,
}
895
impl Clone for Compressor {
    /// Manual clone: builds a fresh compressor around the shared model
    /// (cheap `Arc` clone), then copies runtime state and buffers with
    /// `clone_from` so the freshly allocated buffers are reused.
    fn clone(&self) -> Self {
        let mut cloned = Self::new_from_model(self.model.clone());
        cloned.state = self.state.clone();
        cloned.pdf_buffer.clone_from(&self.pdf_buffer);
        cloned.cdf_buffer_ac.clone_from(&self.cdf_buffer_ac);
        cloned.ac_freq_buffer.clone_from(&self.ac_freq_buffer);
        cloned.cdf_buffer_rans.clone_from(&self.cdf_buffer_rans);
        cloned.rans_freq_buffer.clone_from(&self.rans_freq_buffer);
        cloned.scratch = self.scratch.clone();
        cloned.online = self.online.clone();
        cloned.source_model_path = self.source_model_path.clone();
        cloned
    }
}
911
912impl Compressor {
    /// Loads a model from disk, remembers its path, and attempts to load the
    /// matching optimizer sidecar file if one exists.
    pub fn new<P: AsRef<Path>>(model_path: P) -> Result<Self> {
        let model_path = model_path.as_ref();
        let model = Arc::new(Model::load(model_path)?);
        let mut c = Self::new_from_model(model);
        c.source_model_path = Some(model_path.to_path_buf());
        c.maybe_load_sidecar()?;
        Ok(c)
    }
928
929 pub fn load_model<P: AsRef<Path>>(model_path: P) -> Result<Arc<Model>> {
931 Ok(Arc::new(Model::load(model_path)?))
932 }
933
934 pub fn new_from_model(model: Arc<Model>) -> Self {
936 let state = model.new_state();
937 let vocab_size = model.config().vocab_size;
938 let scratch = ScratchBuffers::new(model.config());
939 Self {
940 model,
941 state,
942 scratch,
943 pdf_buffer: vec![0.0f64; vocab_size],
944 cdf_buffer_ac: vec![0u32; vocab_size + 1],
945 ac_freq_buffer: vec![0i64; vocab_size],
946 cdf_buffer_rans: vec![0u32; vocab_size + 1],
947 rans_freq_buffer: vec![0i64; vocab_size],
948 online: None,
949 source_model_path: None,
950 }
951 }
952
    /// Builds a compressor from a method string (see `parse_method_spec`):
    /// either a model file (optionally wrapped with a policy runtime) or a
    /// freshly built online model, possibly warm-started via `load_from`.
    pub fn new_from_method(method: &str) -> Result<Self> {
        match parse_method_spec(method)? {
            MethodSpec::File { path, policy } => {
                let mut c = Self::new(&path)?;
                if let Some(policy) = policy {
                    let canonical_method =
                        format!("file:{};policy:{}", path.display(), policy.canonical());
                    let hidden = c.model.config().hidden_size;
                    // Reuse an existing runtime (e.g. from a sidecar) or
                    // build one with default cfg; then overwrite its
                    // method/policy to match this request.
                    let mut online = c.online.take().unwrap_or_else(|| {
                        OnlineRuntime::new(
                            OnlineConfig::default(),
                            canonical_method.clone(),
                            Some(policy.clone()),
                            VOCAB_SIZE,
                            hidden,
                        )
                    });
                    online.canonical_method = canonical_method;
                    online.policy = Some(policy);
                    online.needs_full_trace = online
                        .policy
                        .as_ref()
                        .map(policy_needs_full_trace)
                        .unwrap_or(false);
                    c.online = Some(online);
                    // Scratch only records forward traces when the policy
                    // actually trains trace-requiring parameters.
                    c.scratch.set_capture_train_trace(
                        c.online.as_ref().is_some_and(|o| o.needs_full_trace),
                    );
                }
                Ok(c)
            }
            MethodSpec::Online { cfg, policy } => {
                let rwcfg = cfg.to_rwkv_config()?;
                let model = if let Some(load_from) =
                    policy.as_ref().and_then(|p| p.load_from.as_ref())
                {
                    // Warm start: the loaded model must match the requested
                    // cfg shape exactly.
                    let loaded = Arc::new(Model::load(load_from)?);
                    let loaded_cfg = loaded.config();
                    let shape_ok = loaded_cfg.vocab_size == rwcfg.vocab_size
                        && loaded_cfg.hidden_size == rwcfg.hidden_size
                        && loaded_cfg.num_layers == rwcfg.num_layers
                        && loaded_cfg.num_heads == rwcfg.num_heads
                        && loaded_cfg.head_dim == rwcfg.head_dim
                        && loaded_cfg.intermediate_size == rwcfg.intermediate_size
                        && loaded_cfg.decay_low_rank == rwcfg.decay_low_rank
                        && loaded_cfg.a_low_rank == rwcfg.a_low_rank
                        && loaded_cfg.v_low_rank == rwcfg.v_low_rank
                        && loaded_cfg.g_low_rank == rwcfg.g_low_rank;
                    if !shape_ok {
                        bail!(
                            "rwkv policy load_from shape mismatch with cfg (strict match required)"
                        );
                    }
                    loaded
                } else {
                    Arc::new(Model::new_random(rwcfg, cfg.seed)?)
                };
                let mut c = Self::new_from_model(model);
                let mut canonical_method = cfg_to_method_string(&cfg);
                if let Some(policy) = policy.as_ref() {
                    canonical_method.push_str(";policy:");
                    canonical_method.push_str(&policy.canonical());
                }
                c.online = Some(OnlineRuntime::new(
                    cfg,
                    canonical_method,
                    policy,
                    VOCAB_SIZE,
                    c.model.config().hidden_size,
                ));
                c.scratch
                    .set_capture_train_trace(c.online.as_ref().is_some_and(|o| o.needs_full_trace));
                Ok(c)
            }
        }
    }
1030
    /// Resets the recurrent state and drops any buffered TBPTT training
    /// data; model weights and optimizer moments are kept.
    pub fn reset(&mut self) {
        self.state.reset();
        self.clear_online_training_buffers();
    }
1039
1040 fn prepare_policy_stream(&mut self, total_symbols: Option<u64>) -> Result<()> {
1041 if let Some(online) = self.online.as_mut() {
1042 online.prepare_policy_stream(total_symbols)?;
1043 }
1044 Ok(())
1045 }
1046
1047 #[inline]
1048 fn effective_full_bptt(scope: rwkv7::TrainScopeMask, bptt: usize) -> usize {
1049 if scope.trains_non_head_params() && bptt <= 1 {
1050 DEFAULT_FULL_TBPTT_WINDOW
1051 } else {
1052 bptt.max(1)
1053 }
1054 }
1055
1056 fn clear_online_training_buffers(&mut self) {
1057 if let Some(online) = self.online.as_mut()
1058 && let Some(tbptt) = online.full_tbptt.as_mut()
1059 {
1060 tbptt.pending_input_token = None;
1061 tbptt.pending_input_pre_state = None;
1062 tbptt.segment_start_state = None;
1063 tbptt.steps.clear();
1064 tbptt.settings = None;
1065 }
1066 }
1067
    /// Runs one forward step, first snapshotting the input token and the
    /// pre-step recurrent state when full-trace training may need to replay
    /// this step later.
    fn forward_with_online_record(&mut self, token: u32) {
        if let Some(online) = self.online.as_mut()
            && let Some(tbptt) = online.full_tbptt.as_mut()
        {
            // Snapshot must happen before `forward` mutates `self.state`.
            tbptt.pending_input_token = Some(token);
            tbptt.pending_input_pre_state = Some(self.state.clone());
        }
        let _ = self
            .model
            .forward(&mut self.scratch, token, &mut self.state);
    }
1079
    /// Applies the buffered TBPTT segment (if any) as one training update,
    /// then refreshes the internal pdf from the replayed logits.
    ///
    /// The data is extracted in a separate scope first so that the later
    /// `Arc::make_mut` on the model does not conflict with borrows of
    /// `self.online`.
    fn flush_full_tbptt_segment(&mut self) -> Result<()> {
        // Phase 1: take the segment out of the online runtime.
        let extracted = {
            match self.online.as_mut() {
                Some(online) => match online.full_tbptt.as_mut() {
                    Some(tbptt) if !tbptt.steps.is_empty() => {
                        let settings = tbptt.settings.ok_or_else(|| {
                            anyhow::anyhow!("rwkv full tbptt settings are missing")
                        })?;
                        let start_state = tbptt.segment_start_state.clone().ok_or_else(|| {
                            anyhow::anyhow!("rwkv full tbptt segment start is missing")
                        })?;
                        let steps = tbptt.steps.clone();
                        tbptt.steps.clear();
                        tbptt.segment_start_state = None;
                        tbptt.settings = None;
                        // Full Adam state is created lazily, only when Adam
                        // trains non-head parameters for the first time.
                        let need_full_adam = matches!(settings.optimizer, OptimizerKind::Adam)
                            && settings.scope.trains_non_head_params()
                            && online.full_adam.is_none();
                        Some((settings, start_state, steps, need_full_adam))
                    }
                    _ => None,
                },
                None => None,
            }
        };
        let Some((settings, start_state, steps, need_full_adam)) = extracted else {
            return Ok(());
        };

        // Phase 2: allocate full Adam state before borrowing model mutably.
        if need_full_adam {
            let full_adam = self.model.new_full_adam_state();
            if let Some(online) = self.online.as_mut() {
                online.full_adam = Some(full_adam);
            }
        }

        // Phase 3: copy-on-write the model and replay the segment.
        let model = Arc::make_mut(&mut self.model);
        let Some(online) = self.online.as_mut() else {
            return Ok(());
        };
        model.online_train_segment_tbptt(
            &mut self.scratch,
            &start_state,
            &steps,
            settings.scope,
            settings.optimizer,
            settings.lr,
            settings.clip,
            TBPTT_REPLAY_CHUNK,
            &mut online.adam_t,
            online.full_adam.as_mut(),
            // Bias buffers are passed only when the scope trains the bias.
            if settings.scope.bias {
                Some(online.out_bias.as_mut_slice())
            } else {
                None
            },
            if settings.scope.bias {
                online.adam_m.as_deref_mut()
            } else {
                None
            },
            if settings.scope.bias {
                online.adam_v.as_deref_mut()
            } else {
                None
            },
            &mut self.state,
        )?;
        // Refresh the cached pdf from the logits left in scratch.
        let bias = self.online.as_ref().map(|o| o.out_bias.as_slice());
        Self::logits_to_pdf(self.scratch.logits(), bias, &mut self.pdf_buffer);
        Ok(())
    }
1152
    /// Buffers one (input token, target symbol) pair for full-parameter
    /// TBPTT training, flushing the current segment first when the train
    /// settings changed, and flushing afterwards once the window is full.
    fn enqueue_full_tbptt_step(
        &mut self,
        settings: FullTrainSettings,
        target_symbol: u8,
    ) -> Result<()> {
        // A settings change invalidates the buffered segment: flush it
        // under its old settings before recording with the new ones.
        let should_flush = {
            let Some(online) = self.online.as_mut() else {
                return Ok(());
            };
            let Some(tbptt) = online.full_tbptt.as_mut() else {
                bail!("rwkv full-parameter online training requires trace-enabled tbptt runtime");
            };
            tbptt.settings.is_some_and(|prev| {
                !prev.matches(
                    settings.optimizer,
                    settings.lr,
                    settings.scope,
                    settings.bptt,
                    settings.clip,
                )
            }) && !tbptt.steps.is_empty()
        };
        if should_flush {
            self.flush_full_tbptt_segment()?;
        }

        let flush_now = {
            let Some(online) = self.online.as_mut() else {
                return Ok(());
            };
            let Some(tbptt) = online.full_tbptt.as_mut() else {
                bail!("rwkv full-parameter online training requires trace-enabled tbptt runtime");
            };
            // Without a pending forward step there is nothing to record.
            let Some(input_token) = tbptt.pending_input_token.take() else {
                return Ok(());
            };
            let input_pre_state = tbptt
                .pending_input_pre_state
                .take()
                .ok_or_else(|| anyhow::anyhow!("rwkv full tbptt pending pre-state is missing"))?;
            // First step of a fresh segment: its pre-state is the replay origin.
            if tbptt.steps.is_empty() {
                tbptt.segment_start_state = Some(input_pre_state);
            }
            tbptt.settings = Some(settings);
            tbptt.steps.push((input_token, target_symbol));
            tbptt.steps.len() >= settings.bptt.max(1)
        };
        if flush_now {
            self.flush_full_tbptt_segment()?;
        }
        Ok(())
    }
1205
    /// Starts a new policy stream: flushes any buffered training segment
    /// from the previous stream, then compiles the policy schedule for
    /// `total_symbols` (`None` when the length is unknown).
    pub fn begin_online_policy_stream(&mut self, total_symbols: Option<u64>) -> Result<()> {
        self.finish_online_policy_stream()?;
        self.prepare_policy_stream(total_symbols)
    }
1211
    /// Ends the current policy stream, applying any buffered (partial)
    /// TBPTT training segment.
    pub fn finish_online_policy_stream(&mut self) -> Result<()> {
        self.flush_full_tbptt_segment()
    }
1216
    /// Finishes the current stream, resets the recurrent state and TBPTT
    /// buffers, then prepares the policy for a fresh stream.
    pub fn restart_online_policy_stream(&mut self, total_symbols: Option<u64>) -> Result<()> {
        self.finish_online_policy_stream()?;
        self.state.reset();
        self.clear_online_training_buffers();
        self.prepare_policy_stream(total_symbols)
    }
1224
    /// Resets state and training buffers, then primes the internal pdf by
    /// running one forward step on token 0.
    pub fn reset_and_prime(&mut self) {
        self.state.reset();
        self.clear_online_training_buffers();
        self.refresh_current_pdf(0);
    }
1231
1232 pub fn snapshot_runtime(&self) -> RuntimeSnapshot {
1234 RuntimeSnapshot {
1235 model: self.model.clone(),
1236 scratch: self.scratch.clone(),
1237 state: self.state.clone(),
1238 pdf_buffer: self.pdf_buffer.clone(),
1239 online: self.online.clone(),
1240 }
1241 }
1242
1243 pub fn restore_runtime(&mut self, snapshot: &RuntimeSnapshot) {
1245 self.model = snapshot.model.clone();
1246 self.scratch = snapshot.scratch.clone();
1247 self.state = snapshot.state.clone();
1248 self.pdf_buffer.clone_from(&snapshot.pdf_buffer);
1249 self.online = snapshot.online.clone();
1250 }
1251
1252 pub fn absorb_chain(&mut self, parts: &[&[u8]]) -> Result<()> {
1254 let total = parts
1255 .iter()
1256 .fold(0u64, |acc, part| acc.saturating_add(part.len() as u64));
1257 self.fit_chain(parts, Some(total))
1258 }
1259
1260 pub fn cross_entropy_from_current(&mut self, data: &[u8]) -> Result<f64> {
1262 if data.is_empty() {
1263 return Ok(0.0);
1264 }
1265 self.begin_online_policy_stream(Some(data.len() as u64))?;
1266 let mut total_bits = 0.0f64;
1267 for &byte in data {
1268 total_bits -= self.pdf_buffer[byte as usize].log2();
1269 self.observe_symbol_from_current_pdf(byte)?;
1270 }
1271 self.finish_online_policy_stream()?;
1272 Ok(total_bits / (data.len() as f64))
1273 }
1274
    /// Cross-entropy (bits per byte) of `data` under a model that is first
    /// adapted on `fit_parts` and then evaluated frozen (inference only).
    ///
    /// Falls back to plain `cross_entropy` when no online adaptation is
    /// configured. Returns 0.0 for empty input.
    pub fn cross_entropy_frozen_plugin_chain(
        &mut self,
        fit_parts: &[&[u8]],
        data: &[u8],
    ) -> Result<f64> {
        if data.is_empty() {
            return Ok(0.0);
        }
        if !self.can_adapt_online() {
            return self.cross_entropy(data);
        }
        self.finish_online_policy_stream()?;
        self.reset_and_prime();
        let fit_total = fit_parts
            .iter()
            .fold(0u64, |acc, part| acc.saturating_add(part.len() as u64));
        self.fit_chain(fit_parts, Some(fit_total))?;
        // Re-prime so evaluation starts from a clean state.
        self.reset_and_prime();

        let mut total_bits = 0.0f64;
        for &byte in data {
            // Floor keeps log2 finite even for a degenerate pdf entry.
            total_bits -= self.pdf_buffer[byte as usize].max(1e-300).log2();
            self.advance_inference_only(byte);
        }
        Ok(total_bits / (data.len() as f64))
    }
1302
    /// True when this compressor carries an online runtime (built via a
    /// `cfg:` method or a policy), regardless of whether training is enabled.
    pub fn is_online(&self) -> bool {
        self.online.is_some()
    }
1307
1308 pub fn can_adapt_online(&self) -> bool {
1310 let Some(online) = &self.online else {
1311 return false;
1312 };
1313 match &online.policy {
1314 Some(policy) => llm_policy::policy_can_train(policy),
1315 None => !matches!(online.cfg.train_mode, OnlineTrainMode::None),
1316 }
1317 }
1318
1319 pub fn tokens_processed(&self) -> u64 {
1321 self.online.as_ref().map_or(0, |s| s.tokens_processed)
1322 }
1323
1324 pub fn online_method_string(&self) -> Option<&str> {
1326 self.online.as_ref().map(|s| s.canonical_method.as_str())
1327 }
1328
1329 pub fn vocab_size(&self) -> usize {
1331 self.model.config().vocab_size
1332 }
1333
1334 pub fn online_apply_logits_bias(&self, logits: &[f32], pdf_out: &mut [f64]) {
1336 let bias = self.online.as_ref().map(|s| s.out_bias.as_slice());
1337 Self::logits_to_pdf(logits, bias, pdf_out);
1338 }
1339
1340 pub fn logits_to_pdf(logits: &[f32], bias: Option<&[f32]>, pdf_out: &mut [f64]) {
1342 softmax_pdf_floor_with_bias(logits, bias, pdf_out);
1343 }
1344
1345 #[inline]
1346 pub fn forward_to_pdf(&mut self, token: u32, pdf_out: &mut [f64]) {
1348 self.forward_with_online_record(token);
1349 let bias = self.online.as_ref().map(|o| o.out_bias.as_slice());
1350 Self::logits_to_pdf(self.scratch.logits(), bias, pdf_out);
1351 }
1352
1353 #[inline]
1354 pub fn forward_to_internal_pdf(&mut self, token: u32) {
1356 self.refresh_current_pdf(token);
1357 }
1358
1359 #[inline]
1360 pub fn copy_current_pdf_to(&self, pdf_out: &mut [f64]) {
1362 assert_eq!(
1363 pdf_out.len(),
1364 self.pdf_buffer.len(),
1365 "rwkv pdf output length mismatch"
1366 );
1367 pdf_out.copy_from_slice(&self.pdf_buffer);
1368 }
1369
1370 pub fn online_bias_snapshot(&self) -> Option<Vec<f32>> {
1372 self.online.as_ref().map(|o| o.out_bias.clone())
1373 }
1374
1375 #[inline]
1376 pub fn online_bias_slice(&self) -> Option<&[f32]> {
1378 self.online.as_ref().map(|o| o.out_bias.as_slice())
1379 }
1380
1381 #[inline]
1382 fn refresh_current_pdf(&mut self, token: u32) {
1383 self.forward_with_online_record(token);
1384 let bias = self.online.as_ref().map(|o| o.out_bias.as_slice());
1385 Self::logits_to_pdf(self.scratch.logits(), bias, &mut self.pdf_buffer);
1386 }
1387
1388 fn fit_chain(&mut self, parts: &[&[u8]], total_symbols: Option<u64>) -> Result<()> {
1389 self.begin_online_policy_stream(total_symbols)?;
1390 for part in parts {
1391 for &byte in *part {
1392 self.observe_symbol_from_current_pdf(byte)?;
1393 }
1394 }
1395 self.finish_online_policy_stream()?;
1396 Ok(())
1397 }
1398
1399 #[inline]
1400 fn advance_inference_only(&mut self, symbol: u8) {
1401 self.refresh_current_pdf(symbol as u32);
1402 }
1403
    /// Decide how the next symbol should be trained on.
    ///
    /// Starts from the runtime's static config (optimizer, lr, stride) with a
    /// default scope of head+bias when training is enabled at all, then lets
    /// the next policy action — if one is pending — override everything.
    ///
    /// Returns `(optimizer, lr, stride, scope, bptt, clip)`.
    ///
    /// # Errors
    /// Propagates failures from `next_policy_action`.
    fn resolve_online_train_action(
        online: &mut OnlineRuntime,
    ) -> Result<(OptimizerKind, f32, u64, rwkv7::TrainScopeMask, usize, f32)> {
        // Baseline from static config. `None` maps to SGD here, but the
        // scope flags below stay false in that case, so nothing trains.
        let mut optimizer = match online.cfg.train_mode {
            OnlineTrainMode::None => OptimizerKind::Sgd,
            OnlineTrainMode::Sgd => OptimizerKind::Sgd,
            OnlineTrainMode::Adam => OptimizerKind::Adam,
        };
        let mut lr = online.cfg.lr.max(0.0);
        let mut stride = online.cfg.stride.max(1) as u64;
        let mut scope = rwkv7::TrainScopeMask::default();
        let default_train = !matches!(online.cfg.train_mode, OnlineTrainMode::None);
        scope.head = default_train;
        scope.bias = default_train;
        let mut bptt = 1usize;
        let mut clip = 0.0f32;

        if let Some(action) = online.next_policy_action()? {
            match action {
                PolicyAction::Infer => {
                    // Explicit inference step: clear the scope so no
                    // parameters are trained.
                    scope = rwkv7::TrainScopeMask::default();
                }
                PolicyAction::Train(train) => {
                    // Policy overrides every baseline hyper-parameter.
                    optimizer = train.optimizer;
                    lr = train.hyper.lr.max(0.0);
                    stride = train.hyper.stride.max(1) as u64;
                    bptt = train.hyper.bptt.max(1);
                    clip = train.hyper.clip.max(0.0);
                    scope = scope_from_train_action(&train);
                }
            }
        }

        Ok((optimizer, lr, stride, scope, bptt, clip))
    }
1439
1440 pub fn online_update_from_pdf(&mut self, symbol: u8, pdf: &[f64]) -> Result<()> {
1442 self.online_update_with_pdf(symbol, pdf)
1443 }
1444
1445 #[inline]
1446 pub fn observe_symbol_from_pdf(&mut self, symbol: u8, pdf: &[f64]) -> Result<()> {
1448 self.online_update_with_pdf(symbol, pdf)?;
1449 self.refresh_current_pdf(symbol as u32);
1450 Ok(())
1451 }
1452
    /// Core online training step for one observed `symbol` scored under `pdf`.
    ///
    /// Resolves the effective training action (policy override or config
    /// defaults), then dispatches to the cheapest applicable path: no-op,
    /// head/bias-only LM-head update, or full TBPTT enqueue.
    ///
    /// # Errors
    /// Propagates policy-resolution and TBPTT flush failures.
    fn online_update_with_pdf(&mut self, symbol: u8, pdf: &[f64]) -> Result<()> {
        // Resolve the action in an inner scope so the mutable borrow of
        // `self.online` ends before the update paths below re-borrow `self`.
        let (optimizer, lr, stride_hit, scope, bptt, clip) = {
            let Some(online) = self.online.as_mut() else {
                // Not an online compressor: nothing to update.
                return Ok(());
            };
            online.tokens_processed = online.tokens_processed.saturating_add(1);
            let (optimizer, lr, stride, scope, bptt, clip) =
                Self::resolve_online_train_action(online)?;
            // Stride is counted against *training* steps only, so
            // inference-only actions do not consume stride budget.
            let mut stride_hit = false;
            if scope.trains_any_params() {
                online.policy_train_steps = online.policy_train_steps.saturating_add(1);
                stride_hit = stride <= 1 || (online.policy_train_steps % stride) == 0;
            }
            (optimizer, lr, stride_hit, scope, bptt, clip)
        };

        if !scope.trains_any_params() || !stride_hit || lr == 0.0 {
            // Skipped step: flush any buffered TBPTT segment and clear the
            // pending input so stale activations are not trained on later.
            self.flush_full_tbptt_segment()?;
            if let Some(online) = self.online.as_mut()
                && let Some(tbptt) = online.full_tbptt.as_mut()
            {
                tbptt.pending_input_token = None;
                tbptt.pending_input_pre_state = None;
            }
            return Ok(());
        }

        // Lazily allocate Adam moment buffers for the output bias.
        if matches!(optimizer, OptimizerKind::Adam)
            && let Some(online) = self.online.as_mut()
            && scope.bias
            && (online.adam_m.is_none() || online.adam_v.is_none())
        {
            online.adam_m = Some(vec![0.0; online.out_bias.len()]);
            online.adam_v = Some(vec![0.0; online.out_bias.len()]);
        }

        if !scope.trains_non_head_params() {
            // Head/bias-only scope: a direct LM-head gradient step suffices;
            // no backprop-through-time is needed.
            let hidden = self.scratch.lm_head_input().to_vec();
            let pdf_snapshot = pdf.to_vec();
            self.flush_full_tbptt_segment()?;
            let Some(online) = self.online.as_mut() else {
                return Ok(());
            };
            if let Some(tbptt) = online.full_tbptt.as_mut() {
                tbptt.pending_input_token = None;
                tbptt.pending_input_pre_state = None;
            }
            // Copy-on-write: clone the model only when the Arc is shared.
            let model = Arc::make_mut(&mut self.model);
            apply_online_lm_head_update(
                model,
                online,
                &hidden,
                symbol,
                &pdf_snapshot,
                lr,
                optimizer,
                scope.head,
                scope.bias,
                clip,
            );
            return Ok(());
        }

        // Full-parameter training: queue this step for (truncated) BPTT.
        let settings = FullTrainSettings {
            optimizer,
            lr,
            scope,
            bptt: Self::effective_full_bptt(scope, bptt),
            clip,
        };
        self.enqueue_full_tbptt_step(settings, symbol)
    }
1525
1526 fn online_update_from_current_pdf(&mut self, symbol: u8) -> Result<()> {
1527 let pdf_snapshot = self.pdf_buffer.clone();
1528 self.online_update_with_pdf(symbol, &pdf_snapshot)
1529 }
1530
1531 #[inline]
1532 pub fn observe_symbol_from_current_pdf(&mut self, symbol: u8) -> Result<()> {
1534 self.online_update_from_current_pdf(symbol)?;
1535 self.refresh_current_pdf(symbol as u32);
1536 Ok(())
1537 }
1538
    /// Persist the model (safetensors) plus a JSON sidecar describing the
    /// online-training state, so a later load can resume byte-exactly. A
    /// second safetensors sidecar holds full-Adam moments when present; a
    /// stale optimizer sidecar is removed otherwise.
    ///
    /// # Errors
    /// Fails on filesystem or serialization errors.
    pub fn export_online<P: AsRef<Path>>(&self, model_path: P) -> Result<()> {
        let model_path = model_path.as_ref();
        self.model.save_safetensors(model_path)?;
        let opt_sidecar = optimizer_sidecar_path(model_path);

        let sidecar = model_path.with_extension("json");
        let meta = if let Some(online) = &self.online {
            // Full-Adam moments live in a separate safetensors file; drop a
            // leftover one if this runtime no longer carries them.
            if let Some(full_adam) = online.full_adam.as_ref() {
                self.model
                    .save_full_adam_safetensors(full_adam, &opt_sidecar)?;
            } else if opt_sidecar.exists() {
                let _ = fs::remove_file(&opt_sidecar);
            }
            let train_mode = match online.cfg.train_mode {
                OnlineTrainMode::None => "none",
                OnlineTrainMode::Sgd => "sgd",
                OnlineTrainMode::Adam => "adam",
            };
            // Schema consumed by maybe_load_sidecar; keys must stay stable.
            json!({
                "version": 1,
                "method": online.canonical_method,
                "policy": online.policy.as_ref().map(LlmPolicy::canonical),
                "policy_cursor": online.policy_runtime.as_ref().map(PolicyRuntime::cursor).unwrap_or(0),
                "policy_stream_total": online.policy_stream_total,
                "policy_train_steps": online.policy_train_steps,
                "training_mode": train_mode,
                "tokens_processed": online.tokens_processed,
                "adam_t": online.adam_t,
                "has_full_adam": online.full_adam.is_some(),
                "config": {
                    "hidden": online.cfg.hidden,
                    "layers": online.cfg.layers,
                    "intermediate": online.cfg.intermediate,
                    "decay_rank": online.cfg.decay_rank,
                    "a_rank": online.cfg.a_rank,
                    "v_rank": online.cfg.v_rank,
                    "g_rank": online.cfg.g_rank,
                    "seed": online.cfg.seed,
                    "lr": online.cfg.lr,
                    "stride": online.cfg.stride.max(1),
                },
                "output_bias": online.out_bias,
                "adam_m": online.adam_m,
                "adam_v": online.adam_v,
                "lm_head_adam_m": online.lm_head_adam_m,
                "lm_head_adam_v": online.lm_head_adam_v,
            })
        } else {
            // Not online: minimal sidecar, and no optimizer file should
            // remain on disk.
            if opt_sidecar.exists() {
                let _ = fs::remove_file(&opt_sidecar);
            }
            json!({
                "version": 1,
                "method": format!("file:{}", model_path.display()),
                "training_mode": "none",
                "tokens_processed": 0,
            })
        };

        fs::write(&sidecar, serde_json::to_vec_pretty(&meta)?)?;
        Ok(())
    }
1602
    /// If a JSON sidecar exists next to the source model, rebuild the online
    /// runtime from it (bias vector, optimizer moments, policy cursor, token
    /// counters) so compression resumes from the exported state.
    ///
    /// The presence of `output_bias` in the sidecar is what gates creating an
    /// online runtime at all; sidecars without it leave `self.online` as-is.
    ///
    /// # Errors
    /// Fails on unreadable/invalid sidecar JSON, or when the sidecar claims a
    /// full-Adam optimizer state whose safetensors file is missing.
    fn maybe_load_sidecar(&mut self) -> Result<()> {
        // No recorded source path means there is nothing to resume from.
        let Some(model_path) = &self.source_model_path else {
            return Ok(());
        };
        let sidecar = model_path.with_extension("json");
        if !sidecar.exists() {
            return Ok(());
        }
        let raw = fs::read(&sidecar)?;
        let v: serde_json::Value = serde_json::from_slice(&raw)?;
        // Helper: read an optional f32 array field (absent key => None;
        // non-numeric entries fall back to 0.0).
        let parse_vec_f32 = |key: &str| -> Option<Vec<f32>> {
            v.get(key).and_then(|arr| arr.as_array()).map(|arr| {
                arr.iter()
                    .map(|x| x.as_f64().unwrap_or(0.0) as f32)
                    .collect::<Vec<f32>>()
            })
        };
        let output_bias = v
            .get("output_bias")
            .and_then(|arr| arr.as_array())
            .map(|arr| {
                arr.iter()
                    .map(|x| x.as_f64().unwrap_or(0.0) as f32)
                    .collect::<Vec<f32>>()
            });
        let method = v
            .get("method")
            .and_then(|m| m.as_str())
            .map(|s| s.to_string())
            .unwrap_or_else(|| format!("file:{}", model_path.display()));
        let has_full_adam = v
            .get("has_full_adam")
            .and_then(|x| x.as_bool())
            .unwrap_or(false);
        // An unparsable policy string is silently treated as "no policy".
        let policy = v
            .get("policy")
            .and_then(|p| p.as_str())
            .and_then(|s| llm_policy::parse_policy_segment(s, RWKV_TRAIN_SCOPES).ok());
        let tokens = v
            .get("tokens_processed")
            .and_then(|t| t.as_u64())
            .unwrap_or(0);
        if let Some(mut out_bias) = output_bias {
            // Pad/truncate the bias to the current model's vocabulary.
            out_bias.resize(self.vocab_size(), 0.0);
            // Rebuild the OnlineConfig field-by-field; any missing key keeps
            // its default.
            let mut cfg = OnlineConfig::default();
            if let Some(cfg_v) = v.get("config").and_then(|x| x.as_object()) {
                if let Some(x) = cfg_v.get("hidden").and_then(|x| x.as_u64()) {
                    cfg.hidden = x as usize;
                }
                if let Some(x) = cfg_v.get("layers").and_then(|x| x.as_u64()) {
                    cfg.layers = x as usize;
                }
                if let Some(x) = cfg_v.get("intermediate").and_then(|x| x.as_u64()) {
                    cfg.intermediate = x as usize;
                }
                if let Some(x) = cfg_v.get("decay_rank").and_then(|x| x.as_u64()) {
                    cfg.decay_rank = x as usize;
                }
                if let Some(x) = cfg_v.get("a_rank").and_then(|x| x.as_u64()) {
                    cfg.a_rank = x as usize;
                }
                if let Some(x) = cfg_v.get("v_rank").and_then(|x| x.as_u64()) {
                    cfg.v_rank = x as usize;
                }
                if let Some(x) = cfg_v.get("g_rank").and_then(|x| x.as_u64()) {
                    cfg.g_rank = x as usize;
                }
                if let Some(x) = cfg_v.get("seed").and_then(|x| x.as_u64()) {
                    cfg.seed = x;
                }
                if let Some(x) = cfg_v.get("lr").and_then(|x| x.as_f64()) {
                    cfg.lr = x as f32;
                }
                if let Some(x) = cfg_v.get("stride").and_then(|x| x.as_u64()) {
                    cfg.stride = (x as usize).max(1);
                }
            }
            cfg.train_mode = v
                .get("training_mode")
                .and_then(|x| x.as_str())
                .and_then(|s| parse_train_mode_token(s).ok())
                .unwrap_or(OnlineTrainMode::None);
            let needs_full_trace = policy
                .as_ref()
                .map(policy_needs_full_trace)
                .unwrap_or(false);
            self.online = Some(OnlineRuntime {
                cfg,
                canonical_method: method,
                policy,
                policy_runtime: None,
                needs_full_trace,
                policy_stream_total: v.get("policy_stream_total").and_then(|x| x.as_u64()),
                policy_train_steps: v
                    .get("policy_train_steps")
                    .and_then(|x| x.as_u64())
                    .unwrap_or(0),
                tokens_processed: tokens,
                out_bias,
                adam_m: parse_vec_f32("adam_m"),
                adam_v: parse_vec_f32("adam_v"),
                full_adam: None,
                lm_head_adam_m: parse_vec_f32("lm_head_adam_m"),
                lm_head_adam_v: parse_vec_f32("lm_head_adam_v"),
                adam_t: v.get("adam_t").and_then(|x| x.as_u64()).unwrap_or(0) as usize,
                // TBPTT scratch is only allocated when the policy needs a
                // full trace.
                full_tbptt: needs_full_trace.then(|| FullTbpttRuntime {
                    pending_input_token: None,
                    pending_input_pre_state: None,
                    segment_start_state: None,
                    steps: Vec::new(),
                    settings: None,
                }),
            });
            // Full-Adam moments live in a separate safetensors sidecar; a
            // sidecar that claims them but lacks the file is a hard error
            // because resume would silently diverge otherwise.
            let opt_sidecar = optimizer_sidecar_path(model_path);
            if opt_sidecar.exists() {
                if let Some(online) = self.online.as_mut() {
                    online.full_adam = Some(self.model.load_full_adam_safetensors(&opt_sidecar)?);
                }
            } else if has_full_adam {
                bail!(
                    "missing optimizer sidecar '{}' required for exact online resume",
                    opt_sidecar.display()
                );
            }
            // Restore the policy cursor, preserving the train-step counter
            // that prepare_policy_stream would otherwise reset.
            if let Some(cursor) = v.get("policy_cursor").and_then(|x| x.as_u64())
                && let Some(online) = self.online.as_mut()
                && online.policy.is_some()
            {
                let train_steps = online.policy_train_steps;
                online.prepare_policy_stream(online.policy_stream_total)?;
                online.policy_train_steps = train_steps;
                if let Some(rt) = online.policy_runtime.as_mut() {
                    rt.set_cursor(cursor);
                }
            }
            self.scratch
                .set_capture_train_trace(self.online.as_ref().is_some_and(|o| o.needs_full_trace));
        }
        Ok(())
    }
1743
1744 pub fn compress(&mut self, data: &[u8], coder: CoderType) -> Result<Vec<u8>> {
1753 let mut output = Vec::new();
1754 self.compress_into(data, coder, &mut output)?;
1755 Ok(output)
1756 }
1757
1758 pub fn compress_into<W: Write>(
1760 &mut self,
1761 data: &[u8],
1762 coder: CoderType,
1763 w: &mut W,
1764 ) -> Result<()> {
1765 self.restart_online_policy_stream(Some(data.len() as u64))?;
1766
1767 let checksum = crc32(data);
1768 let header = Header::new(coder, data.len() as u64, checksum);
1769 header.write(w)?;
1770
1771 match coder {
1772 CoderType::AC => self.compress_ac(data, w)?,
1773 CoderType::RANS => self.compress_rans(data, w)?,
1774 }
1775
1776 Ok(())
1777 }
1778
1779 pub fn compress_chain_into<W: Write>(
1781 &mut self,
1782 parts: &[&[u8]],
1783 coder: CoderType,
1784 w: &mut W,
1785 ) -> Result<()> {
1786 let mut total_len: u64 = 0;
1787 let mut hasher = crc32fast::Hasher::new();
1788 for p in parts {
1789 total_len = total_len.saturating_add(p.len() as u64);
1790 hasher.update(p);
1791 }
1792 let checksum = hasher.finalize();
1793 self.restart_online_policy_stream(Some(total_len))?;
1794
1795 let header = Header::new(coder, total_len, checksum);
1796 header.write(w)?;
1797
1798 let it = parts.iter().flat_map(|p| p.iter().copied());
1799 match coder {
1800 CoderType::AC => self.compress_ac_iter(it, w)?,
1801 CoderType::RANS => self.compress_rans_iter(it, w)?,
1802 }
1803
1804 Ok(())
1805 }
1806
1807 pub fn compress_size(&mut self, data: &[u8], coder: CoderType) -> Result<u64> {
1809 let mut w = CountingWriter::new();
1810 self.compress_into(data, coder, &mut w)?;
1811 Ok(w.bytes_written())
1812 }
1813
1814 pub fn compress_size_chain(&mut self, parts: &[&[u8]], coder: CoderType) -> Result<u64> {
1816 let mut w = CountingWriter::new();
1817 self.compress_chain_into(parts, coder, &mut w)?;
1818 Ok(w.bytes_written())
1819 }
1820
1821 fn compress_ac<W: Write>(&mut self, data: &[u8], output: &mut W) -> Result<()> {
1823 self.compress_ac_iter(data.iter().copied(), output)
1824 }
1825
    /// Arithmetic-code a byte stream: for each symbol, quantize the current
    /// pdf into a CDF, emit the symbol's count interval, then adapt on it.
    ///
    /// # Errors
    /// Propagates encoder I/O and online-update failures.
    fn compress_ac_iter<I, W: Write>(&mut self, data: I, output: &mut W) -> Result<()>
    where
        I: IntoIterator<Item = u8>,
    {
        let mut encoder = ArithmeticEncoder::new(output);

        // Prime the pdf with token 0 so the first symbol has a distribution;
        // decompress_ac mirrors this priming.
        self.refresh_current_pdf(0);

        for byte in data {
            quantize_pdf_to_cdf_with_buffer(
                &self.pdf_buffer,
                &mut self.cdf_buffer_ac,
                &mut self.ac_freq_buffer,
            );
            let sym = byte as usize;
            let c_lo = self.cdf_buffer_ac[sym] as u64;
            let c_hi = self.cdf_buffer_ac[sym + 1] as u64;
            encoder.encode_counts(c_lo, c_hi, CDF_TOTAL as u64)?;
            // Adapt only after encoding, keeping encoder and decoder in
            // lockstep on identical model states.
            self.observe_symbol_from_current_pdf(byte)?;
        }

        let _ = encoder.finish()?;
        self.finish_online_policy_stream()?;
        Ok(())
    }
1852
1853 fn compress_rans<W: Write>(&mut self, data: &[u8], output: &mut W) -> Result<()> {
1855 self.compress_rans_iter(data.iter().copied(), output)
1856 }
1857
    /// rANS-encode a byte stream, then frame the payload as a little-endian
    /// u32 block count followed by length-prefixed blocks (the format
    /// decompress_rans expects).
    ///
    /// # Errors
    /// Propagates output I/O and online-update failures.
    fn compress_rans_iter<I, W: Write>(&mut self, data: I, output: &mut W) -> Result<()>
    where
        I: IntoIterator<Item = u8>,
    {
        let mut encoder = BlockedRansEncoder::new();

        // Prime the pdf with token 0; decompress_rans mirrors this priming.
        self.refresh_current_pdf(0);

        for byte in data {
            quantize_pdf_to_rans_cdf_with_buffer(
                &self.pdf_buffer,
                &mut self.cdf_buffer_rans,
                &mut self.rans_freq_buffer,
            );
            let sym = byte as usize;
            let cdf = Cdf::new(
                self.cdf_buffer_rans[sym],
                self.cdf_buffer_rans[sym + 1],
                ANS_TOTAL,
            );
            encoder.encode(cdf);
            // Adapt only after encoding, keeping encoder and decoder in
            // lockstep on identical model states.
            self.observe_symbol_from_current_pdf(byte)?;
        }

        let blocks = encoder.finish();

        // Frame: block count, then (len, bytes) per block.
        output.write_all(&(blocks.len() as u32).to_le_bytes())?;

        for block in &blocks {
            output.write_all(&(block.len() as u32).to_le_bytes())?;
            output.write_all(block)?;
        }

        self.finish_online_policy_stream()?;
        Ok(())
    }
1899
1900 pub fn decompress(&mut self, data: &[u8]) -> Result<Vec<u8>> {
1908 let mut cursor = Cursor::new(data);
1909 let header = Header::read(&mut cursor)?;
1910
1911 self.restart_online_policy_stream(Some(header.original_len))?;
1912
1913 let compressed = &data[Header::SIZE..];
1914 let result = match header.coder_type() {
1915 CoderType::AC => self.decompress_ac(compressed, header.original_len as usize)?,
1916 CoderType::RANS => self.decompress_rans(compressed, header.original_len as usize)?,
1917 };
1918
1919 let actual_crc = crc32(&result);
1921 if actual_crc != header.crc32 {
1922 bail!(
1923 "CRC32 mismatch: expected 0x{:08X}, got 0x{:08X}",
1924 header.crc32,
1925 actual_crc
1926 );
1927 }
1928
1929 Ok(result)
1930 }
1931
    /// Arithmetic-decode exactly `original_len` symbols, mirroring
    /// `compress_ac_iter`: same priming, same pdf-to-CDF quantization, and
    /// the same post-symbol adaptation so both sides stay in lockstep.
    ///
    /// # Errors
    /// Propagates decoder and online-update failures.
    fn decompress_ac(&mut self, compressed: &[u8], original_len: usize) -> Result<Vec<u8>> {
        let mut decoder = ArithmeticDecoder::new(compressed)?;

        let mut result = Vec::with_capacity(original_len);

        // Prime with token 0, matching the encoder.
        self.refresh_current_pdf(0);

        for _ in 0..original_len {
            quantize_pdf_to_cdf_with_buffer(
                &self.pdf_buffer,
                &mut self.cdf_buffer_ac,
                &mut self.ac_freq_buffer,
            );
            let sym = decoder.decode_symbol_counts(&self.cdf_buffer_ac, CDF_TOTAL)?;
            result.push(sym as u8);
            self.observe_symbol_from_current_pdf(sym as u8)?;
        }

        self.finish_online_policy_stream()?;
        Ok(result)
    }
1955
1956 fn decompress_rans(&mut self, compressed: &[u8], original_len: usize) -> Result<Vec<u8>> {
1958 if compressed.len() < 4 {
1960 bail!("rANS data too short");
1961 }
1962 let block_count =
1963 u32::from_le_bytes([compressed[0], compressed[1], compressed[2], compressed[3]])
1964 as usize;
1965
1966 let mut blocks = Vec::with_capacity(block_count);
1968 let mut pos = 4;
1969
1970 for _ in 0..block_count {
1971 if pos + 4 > compressed.len() {
1972 bail!("Truncated block header");
1973 }
1974 let block_len = u32::from_le_bytes([
1975 compressed[pos],
1976 compressed[pos + 1],
1977 compressed[pos + 2],
1978 compressed[pos + 3],
1979 ]) as usize;
1980 pos += 4;
1981
1982 if pos + block_len > compressed.len() {
1983 bail!("Truncated block data");
1984 }
1985 blocks.push(&compressed[pos..pos + block_len]);
1986 pos += block_len;
1987 }
1988
1989 let mut decoder = BlockedRansDecoder::new(blocks, original_len)?;
1991 let mut result = Vec::with_capacity(original_len);
1992
1993 self.refresh_current_pdf(0);
1995
1996 for _ in 0..original_len {
1997 quantize_pdf_to_rans_cdf_with_buffer(
1998 &self.pdf_buffer,
1999 &mut self.cdf_buffer_rans,
2000 &mut self.rans_freq_buffer,
2001 );
2002 let sym = decoder.decode(&self.cdf_buffer_rans)?;
2003 result.push(sym as u8);
2004 self.observe_symbol_from_current_pdf(sym as u8)?;
2005 }
2006
2007 self.finish_online_policy_stream()?;
2008 Ok(result)
2009 }
2010
2011 pub fn cross_entropy(&mut self, data: &[u8]) -> Result<f64> {
2022 self.finish_online_policy_stream()?;
2023 self.reset_and_prime();
2024 self.cross_entropy_from_current(data)
2025 }
2026
2027 pub fn cross_entropy_conditional_chain(
2029 &mut self,
2030 prefix_parts: &[&[u8]],
2031 data: &[u8],
2032 ) -> Result<f64> {
2033 if data.is_empty() {
2034 return Ok(0.0);
2035 }
2036 let prefix_len = prefix_parts
2037 .iter()
2038 .fold(0usize, |acc, p| acc.saturating_add(p.len()));
2039 self.finish_online_policy_stream()?;
2040 self.reset_and_prime();
2041 self.fit_chain(prefix_parts, Some((prefix_len + data.len()) as u64))?;
2042
2043 let mut total_bits = 0.0f64;
2044 for &byte in data {
2045 total_bits -= self.pdf_buffer[byte as usize].log2();
2046 self.observe_symbol_from_current_pdf(byte)?;
2047 }
2048
2049 self.finish_online_policy_stream()?;
2050 Ok(total_bits / (data.len() as f64))
2051 }
2052
2053 pub fn cross_entropy_conditional(&mut self, prefix: &[u8], data: &[u8]) -> Result<f64> {
2055 if data.is_empty() {
2056 return Ok(0.0);
2057 }
2058
2059 self.finish_online_policy_stream()?;
2060 self.reset_and_prime();
2061 self.begin_online_policy_stream(Some((prefix.len() + data.len()) as u64))?;
2062
2063 for &byte in prefix {
2065 self.observe_symbol_from_current_pdf(byte)?;
2066 }
2067
2068 let mut total_bits = 0.0f64;
2069 for &byte in data {
2070 total_bits -= self.pdf_buffer[byte as usize].log2();
2071 self.observe_symbol_from_current_pdf(byte)?;
2072 }
2073
2074 self.finish_online_policy_stream()?;
2075 Ok(total_bits / (data.len() as f64))
2076 }
2077
2078 pub fn joint_cross_entropy_aligned_min(&mut self, x: &[u8], y: &[u8]) -> Result<f64> {
2080 let n = x.len().min(y.len());
2081 if n == 0 {
2082 return Ok(0.0);
2083 }
2084
2085 let h_xy = self.joint_cross_entropy_aligned_order(x, y, false)?;
2086 let h_yx = self.joint_cross_entropy_aligned_order(x, y, true)?;
2087 Ok(h_xy.min(h_yx))
2088 }
2089
2090 fn joint_cross_entropy_aligned_order(&mut self, x: &[u8], y: &[u8], swap: bool) -> Result<f64> {
2091 let n = x.len().min(y.len());
2092 if n == 0 {
2093 return Ok(0.0);
2094 }
2095
2096 self.restart_online_policy_stream(Some((2 * n) as u64))?;
2097
2098 self.refresh_current_pdf(0);
2099
2100 let mut total_bits = 0.0f64;
2101 for i in 0..n {
2102 let a = if swap { y[i] } else { x[i] };
2103 let b = if swap { x[i] } else { y[i] };
2104
2105 let pa = self.pdf_buffer[a as usize];
2106 total_bits -= pa.log2();
2107 self.observe_symbol_from_current_pdf(a)?;
2108
2109 let pb = self.pdf_buffer[b as usize];
2110 total_bits -= pb.log2();
2111 self.observe_symbol_from_current_pdf(b)?;
2112 }
2113
2114 self.finish_online_policy_stream()?;
2115 Ok(total_bits / (n as f64))
2116 }
2117}
2118
/// Summary metrics for a single compression run.
#[derive(Debug, Clone)]
pub struct CompressionStats {
    /// Uncompressed input size in bytes.
    pub original_size: usize,
    /// Framed compressed output size in bytes (header included).
    pub compressed_size: usize,
    /// `original_size / compressed_size`; higher means better compression.
    pub ratio: f64,
    /// Compressed bits emitted per input byte.
    pub bits_per_byte: f64,
    /// Wall-clock compression time in seconds.
    pub time_seconds: f64,
    /// Input bytes processed per second.
    pub throughput: f64,
}
2139
2140impl std::fmt::Display for CompressionStats {
2141 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2142 write!(
2143 f,
2144 "{} bytes -> {} bytes | ratio={:.3} | bits/byte={:.3} | time={:.2}s | {:.0} B/s",
2145 self.original_size,
2146 self.compressed_size,
2147 self.ratio,
2148 self.bits_per_byte,
2149 self.time_seconds,
2150 self.throughput,
2151 )
2152 }
2153}
2154
2155pub fn compress_with_stats(
2159 compressor: &mut Compressor,
2160 data: &[u8],
2161 coder: CoderType,
2162) -> Result<(Vec<u8>, CompressionStats)> {
2163 let start = std::time::Instant::now();
2164 let compressed = compressor.compress(data, coder)?;
2165 let elapsed = start.elapsed().as_secs_f64();
2166
2167 let stats = CompressionStats {
2168 original_size: data.len(),
2169 compressed_size: compressed.len(),
2170 ratio: data.len() as f64 / compressed.len() as f64,
2171 bits_per_byte: (compressed.len() as f64 * 8.0) / data.len() as f64,
2172 time_seconds: elapsed,
2173 throughput: data.len() as f64 / elapsed,
2174 };
2175
2176 Ok((compressed, stats))
2177}
2178
2179#[cfg(test)]
2184mod tests {
2185 use super::*;
2186 use std::time::{SystemTime, UNIX_EPOCH};
2187
2188 fn temp_path(name: &str, ext: &str) -> PathBuf {
2189 let ts = SystemTime::now()
2190 .duration_since(UNIX_EPOCH)
2191 .unwrap()
2192 .as_nanos();
2193 std::env::temp_dir().join(format!("infotheory_rwkvzip_{name}_{ts}.{ext}"))
2194 }
2195
2196 #[test]
2197 fn test_header_roundtrip() {
2198 let header = Header::new(CoderType::AC, 12345, 0xDEADBEEF);
2199
2200 let mut buf = Vec::new();
2201 header.write(&mut buf).unwrap();
2202
2203 assert_eq!(buf.len(), Header::SIZE);
2204
2205 let mut cursor = Cursor::new(&buf);
2206 let read_header = Header::read(&mut cursor).unwrap();
2207
2208 assert_eq!(read_header.magic, MAGIC);
2209 assert_eq!(read_header.version, VERSION);
2210 assert_eq!(read_header.coder, 0);
2211 assert_eq!(read_header.original_len, 12345);
2212 assert_eq!(read_header.crc32, 0xDEADBEEF);
2213 }
2214
2215 #[test]
2216 fn test_header_rans() {
2217 let header = Header::new(CoderType::RANS, 67890, 0xCAFEBABE);
2218 assert_eq!(header.coder, 1);
2219 assert_eq!(header.coder_type(), CoderType::RANS);
2220 }
2221
2222 #[test]
2223 fn test_coder_type_display() {
2224 assert_eq!(format!("{}", CoderType::AC), "AC");
2225 assert_eq!(format!("{}", CoderType::RANS), "rANS");
2226 }
2227
2228 #[test]
2229 fn test_crc32() {
2230 let data = b"Hello, World!";
2231 let c = crc32(data);
2232 assert_ne!(c, 0);
2233 assert_eq!(c, crc32(data));
2235 }
2236
2237 #[test]
2238 fn test_crc32_different_data() {
2239 let c1 = crc32(b"Hello");
2240 let c2 = crc32(b"World");
2241 assert_ne!(c1, c2);
2242 }
2243
2244 #[test]
2245 fn test_crc32_known_vector() {
2246 assert_eq!(crc32(b"123456789"), 0xCBF4_3926);
2248 }
2249
2250 #[test]
2251 fn test_header_rejects_invalid_magic() {
2252 let mut buf = Vec::new();
2253 let header = Header::new(CoderType::AC, 1, 2);
2254 header.write(&mut buf).unwrap();
2255 buf[0] ^= 0xFF;
2257
2258 let mut cursor = Cursor::new(&buf);
2259 let err = Header::read(&mut cursor).unwrap_err();
2260 let msg = format!("{err:#}");
2261 assert!(msg.contains("Invalid magic number"));
2262 }
2263
    /// parse_method_spec must accept all four spellings: `file:` paths,
    /// keyed `cfg:` specs, positional cfg specs, and bare filesystem paths.
    #[test]
    fn test_parse_method_spec_file_and_cfg() {
        let p = temp_path("dummy", "bin");
        std::fs::write(&p, b"x").unwrap();

        // Explicit file: prefix.
        match parse_method_spec(&format!("file:{}", p.display())).unwrap() {
            MethodSpec::File { path: got, .. } => assert_eq!(got, p),
            _ => panic!("expected file method"),
        }

        // Keyed cfg spec with a trailing policy segment.
        match parse_method_spec(
            "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=1,train=none,lr=0.01,stride=2;policy:schedule=0..100:infer",
        )
        .unwrap()
        {
            MethodSpec::Online { cfg, .. } => {
                assert_eq!(cfg.hidden, 64);
                assert_eq!(cfg.layers, 1);
                assert_eq!(cfg.seed, 1);
                assert_eq!(cfg.stride, 2);
            }
            _ => panic!("expected cfg method"),
        }

        // Positional cfg spec (hidden,intermediate,layers,...,seed,lr,stride).
        match parse_method_spec("64,64,1,0,7,0.01,2;policy:schedule=0..100:infer").unwrap() {
            MethodSpec::Online { cfg, .. } => {
                assert_eq!(cfg.hidden, 64);
                assert_eq!(cfg.intermediate, 64);
                assert_eq!(cfg.layers, 1);
                assert_eq!(cfg.seed, 7);
                assert_eq!(cfg.stride, 2);
            }
            _ => panic!("expected positional cfg method"),
        }

        // Bare existing path with no prefix.
        match parse_method_spec(&p.display().to_string()).unwrap() {
            MethodSpec::File { path: got, .. } => assert_eq!(got, p),
            _ => panic!("expected file method"),
        }

        std::fs::remove_file(&p).ok();
    }
2307
2308 #[test]
2309 fn test_parse_method_spec_rejects_unknown_cfg_key() {
2310 let err =
2311 parse_method_spec("cfg:hidden=64,wat=1;policy:schedule=0..100:infer").unwrap_err();
2312 assert!(format!("{err:#}").contains("unknown rwkv cfg key"));
2313 }
2314
2315 #[test]
2316 fn test_parse_method_spec_accepts_cfg_without_policy() {
2317 let spec = parse_method_spec("cfg:hidden=64,layers=1,intermediate=64").unwrap();
2318 match spec {
2319 MethodSpec::Online { cfg, policy } => {
2320 assert_eq!(cfg.hidden, 64);
2321 assert_eq!(cfg.layers, 1);
2322 assert_eq!(cfg.intermediate, 64);
2323 assert!(policy.is_none());
2324 }
2325 _ => panic!("expected cfg method"),
2326 }
2327 }
2328
    /// The canonical method string fills in every default value but appends
    /// no `;policy:` suffix when no policy was supplied.
    #[test]
    fn test_canonical_method_omits_policy_when_absent() {
        let c = Compressor::new_from_method("cfg:hidden=64,layers=1,intermediate=64").unwrap();
        assert_eq!(
            c.online_method_string(),
            Some(
                "cfg:hidden=64,layers=1,intermediate=64,decay_rank=32,a_rank=32,v_rank=32,g_rank=64,seed=0,train=none,lr=0.001,stride=1"
            )
        );
    }
2339
    /// Exporting mid-stream and reloading must reproduce the exact same
    /// compressed bytes as continuing with the original compressor.
    #[test]
    fn test_online_export_reload_roundtrip() {
        let method = "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=7,train=sgd,lr=0.01,stride=1;policy:schedule=0..100:train(scope=head+bias,opt=sgd,lr=0.01,stride=1,bptt=1,clip=0,momentum=0.9)";
        let data = b"rwkv online export/load deterministic sample";

        // First compression adapts the model before export.
        let mut c1 = Compressor::new_from_method(method).unwrap();
        let _ = c1.compress(data, CoderType::AC).unwrap();

        let model_path = temp_path("export", "safetensors");
        c1.export_online(&model_path).unwrap();
        // Compress again *after* export so both compressors start from the
        // same exported state.
        let out1_after_export = c1.compress(data, CoderType::AC).unwrap();

        let mut c2 = Compressor::new(&model_path).unwrap();
        let out2 = c2.compress(data, CoderType::AC).unwrap();

        assert_eq!(out1_after_export, out2);
        assert!(model_path.with_extension("json").exists());

        std::fs::remove_file(&model_path).ok();
        std::fs::remove_file(model_path.with_extension("json")).ok();
    }
2361
    /// Restoring a runtime snapshot must make subsequent scoring
    /// deterministic: the same absorb+score sequence after two restores
    /// yields bit-identical results.
    #[test]
    fn test_runtime_snapshot_restores_online_state() {
        let method = "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=9,train=sgd,lr=0.01,stride=1;policy:schedule=0..100:train(scope=head+bias,opt=sgd,lr=0.01,stride=1,bptt=1,clip=0,momentum=0.9)";
        let mut c = Compressor::new_from_method(method).unwrap();
        c.reset_and_prime();
        c.absorb_chain(&[b"prior context".as_slice()]).unwrap();
        let snap = c.snapshot_runtime();

        // Diverge with snippet-a, then restore and try snippet-b twice.
        c.absorb_chain(&[b"snippet-a".as_slice()]).unwrap();
        let score_a = c.cross_entropy_from_current(b"query").unwrap();

        c.restore_runtime(&snap);
        c.absorb_chain(&[b"snippet-b".as_slice()]).unwrap();
        let score_b = c.cross_entropy_from_current(b"query").unwrap();

        c.restore_runtime(&snap);
        c.absorb_chain(&[b"snippet-b".as_slice()]).unwrap();
        let score_b_again = c.cross_entropy_from_current(b"query").unwrap();

        assert!((score_b - score_b_again).abs() < 1e-12);
        let _ = score_a;
    }
2384
    /// With full-scope Adam training (which mutates model weights, not just
    /// the head), restoring a snapshot must still reproduce identical scores.
    #[test]
    fn test_runtime_snapshot_restores_non_head_training_state() {
        let method = "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=15,train=adam,lr=0.0008,stride=1;policy:schedule=0..100:train(scope=all,opt=adam,lr=0.0008,stride=1,bptt=1,clip=0,momentum=0.9)";
        let mut c = Compressor::new_from_method(method).unwrap();
        c.reset_and_prime();
        c.absorb_chain(&[b"prior context".as_slice()]).unwrap();
        let snap = c.snapshot_runtime();

        // Mutate the model so restore actually has something to undo.
        let _ = c
            .cross_entropy_from_current(b"mutate model before restore")
            .unwrap();

        c.restore_runtime(&snap);
        let score_a = c
            .cross_entropy_from_current(b"query after restore")
            .unwrap();

        c.restore_runtime(&snap);
        let score_b = c
            .cross_entropy_from_current(b"query after restore")
            .unwrap();

        assert!((score_a - score_b).abs() < 1e-12);
    }
2409
2410 #[test]
2411 fn test_online_training_updates_lm_head_weights() {
2412 let method = "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=5,train=sgd,lr=0.01,stride=1;policy:schedule=0..100:train(scope=head+bias,opt=sgd,lr=0.01,stride=1,bptt=1,clip=0,momentum=0.9)";
2413 let mut c = Compressor::new_from_method(method).unwrap();
2414 c.reset_and_prime();
2415 let before = c.model.lm_head_weights()[0..64].to_vec();
2416 let _ = c
2417 .cross_entropy_from_current(b"online rwkv weight update")
2418 .unwrap();
2419 let after = &c.model.lm_head_weights()[0..64];
2420 let mut changed = false;
2421 for i in 0..before.len() {
2422 if before[i].to_bits() != after[i].to_bits() {
2423 changed = true;
2424 break;
2425 }
2426 }
2427 assert!(
2428 changed,
2429 "expected LM-head weights to change under online training"
2430 );
2431 }
2432
2433 #[test]
2434 fn test_cross_entropy_from_current_keeps_unique_model_arc() {
2435 let method = "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=21,train=adam,lr=0.0008,stride=1;policy:schedule=0..100:train(scope=all,opt=adam,lr=0.0008,stride=1,bptt=1,clip=0,momentum=0.9)";
2436 let mut c = Compressor::new_from_method(method).unwrap();
2437 c.reset_and_prime();
2438
2439 assert_eq!(Arc::strong_count(&c.model), 1);
2440 let before = Arc::as_ptr(&c.model);
2441 let _ = c.cross_entropy_from_current(b"arc uniqueness").unwrap();
2442 let after = Arc::as_ptr(&c.model);
2443
2444 assert_eq!(Arc::strong_count(&c.model), 1);
2445 assert_eq!(before, after);
2446 }
2447
2448 #[test]
2449 fn test_online_training_non_head_scope_updates_model_params() {
2450 let method = "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=13,train=sgd,lr=0.005,stride=1;policy:schedule=0..100:train(scope=attn,opt=sgd,lr=0.005,stride=1,bptt=1,clip=0,momentum=0.9)";
2451 let mut c = Compressor::new_from_method(method).unwrap();
2452
2453 let head_before = c.model.lm_head_weights()[0..64].to_vec();
2454 let before_path = temp_path("rwkv_non_head_before", "safetensors");
2455 let after_path = temp_path("rwkv_non_head_after", "safetensors");
2456 c.model.save_safetensors(&before_path).unwrap();
2457
2458 c.reset_and_prime();
2459 let _ = c
2460 .cross_entropy_from_current(b"rwkv non head online update")
2461 .unwrap();
2462 c.model.save_safetensors(&after_path).unwrap();
2463
2464 let head_after = &c.model.lm_head_weights()[0..64];
2465 for idx in 0..head_before.len() {
2466 assert_eq!(
2467 head_before[idx].to_bits(),
2468 head_after[idx].to_bits(),
2469 "lm-head changed under scope=attn at index {idx}"
2470 );
2471 }
2472
2473 let before_bytes = std::fs::read(&before_path).unwrap();
2474 let after_bytes = std::fs::read(&after_path).unwrap();
2475 assert_ne!(
2476 before_bytes, after_bytes,
2477 "expected non-head params to change"
2478 );
2479
2480 std::fs::remove_file(&before_path).ok();
2481 std::fs::remove_file(&after_path).ok();
2482 }
2483
2484 #[test]
2485 fn test_online_training_scope_all_bptt_gt_one_supported() {
2486 let method = "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=23,train=adam,lr=0.001,stride=1;policy:schedule=0..100:train(scope=all,opt=adam,lr=0.001,stride=1,bptt=2,clip=0,momentum=0.9)";
2487 let mut c = Compressor::new_from_method(method).unwrap();
2488 let before_path = temp_path("rwkv_tbptt_before", "safetensors");
2489 let after_path = temp_path("rwkv_tbptt_after", "safetensors");
2490 c.model.save_safetensors(&before_path).unwrap();
2491 c.reset_and_prime();
2492 let score = c.cross_entropy_from_current(b"abcdef").unwrap();
2493 assert!(score.is_finite());
2494 c.model.save_safetensors(&after_path).unwrap();
2495 let before_bytes = std::fs::read(&before_path).unwrap();
2496 let after_bytes = std::fs::read(&after_path).unwrap();
2497 assert_ne!(
2498 before_bytes, after_bytes,
2499 "expected tbptt training to update params"
2500 );
2501 std::fs::remove_file(&before_path).ok();
2502 std::fs::remove_file(&after_path).ok();
2503 }
2504
2505 #[test]
2506 fn test_online_training_scope_all_bptt_one_uses_fast_default_window() {
2507 let method_bptt1 = "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=27,train=adam,lr=0.0008,stride=1;policy:schedule=0..100:train(scope=all,opt=adam,lr=0.0008,stride=1,bptt=1,clip=0,momentum=0.9)";
2508 let method_bptt8 = "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=27,train=adam,lr=0.0008,stride=1;policy:schedule=0..100:train(scope=all,opt=adam,lr=0.0008,stride=1,bptt=8,clip=0,momentum=0.9)";
2509 let data = b"abcdefghij";
2510
2511 let mut c1 = Compressor::new_from_method(method_bptt1).unwrap();
2512 let mut c8 = Compressor::new_from_method(method_bptt8).unwrap();
2513
2514 let score1 = c1.cross_entropy(data).unwrap();
2515 let score8 = c8.cross_entropy(data).unwrap();
2516 assert!((score1 - score8).abs() < 1e-12);
2517
2518 let bptt1_path = temp_path("rwkv_bptt1_fast_default", "safetensors");
2519 let bptt8_path = temp_path("rwkv_bptt8_fast_default", "safetensors");
2520 c1.model.save_safetensors(&bptt1_path).unwrap();
2521 c8.model.save_safetensors(&bptt8_path).unwrap();
2522 let bptt1_bytes = std::fs::read(&bptt1_path).unwrap();
2523 let bptt8_bytes = std::fs::read(&bptt8_path).unwrap();
2524 assert_eq!(bptt1_bytes, bptt8_bytes);
2525 std::fs::remove_file(&bptt1_path).ok();
2526 std::fs::remove_file(&bptt8_path).ok();
2527 }
2528
2529 #[test]
2530 fn test_online_training_full_tbptt_updates_first_symbol_after_priming() {
2531 let method = "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=33,train=adam,lr=0.0008,stride=1;policy:schedule=0..100:train(scope=all,opt=adam,lr=0.0008,stride=1,bptt=8,clip=0,momentum=0.9)";
2532 let mut c = Compressor::new_from_method(method).unwrap();
2533 let before_path = temp_path("rwkv_first_symbol_before", "safetensors");
2534 let after_path = temp_path("rwkv_first_symbol_after", "safetensors");
2535 c.model.save_safetensors(&before_path).unwrap();
2536
2537 c.reset_and_prime();
2538 let score = c.cross_entropy_from_current(b"a").unwrap();
2539 assert!(score.is_finite());
2540 c.model.save_safetensors(&after_path).unwrap();
2541
2542 let before_bytes = std::fs::read(&before_path).unwrap();
2543 let after_bytes = std::fs::read(&after_path).unwrap();
2544 assert_ne!(
2545 before_bytes, after_bytes,
2546 "expected the first symbol after priming to update params"
2547 );
2548 std::fs::remove_file(&before_path).ok();
2549 std::fs::remove_file(&after_path).ok();
2550 }
2551
2552 #[test]
2553 fn test_online_export_reload_roundtrip_preserves_full_adam_resume() {
2554 let method = "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=31,train=adam,lr=0.0008,stride=1;policy:schedule=0..100:train(scope=all,opt=adam,lr=0.0008,stride=1,bptt=1,clip=0,momentum=0.9)";
2555 let data = b"rwkv full-adam export/load deterministic continuation sample";
2556
2557 let mut c1 = Compressor::new_from_method(method).unwrap();
2558 let _ = c1.compress(data, CoderType::AC).unwrap();
2559
2560 let model_path = temp_path("rwkv_full_adam_export", "safetensors");
2561 let opt_path = optimizer_sidecar_path(&model_path);
2562 c1.export_online(&model_path).unwrap();
2563 assert!(
2564 opt_path.exists(),
2565 "expected optimizer sidecar to be exported"
2566 );
2567 let out1_after_export = c1.compress(data, CoderType::AC).unwrap();
2568
2569 let mut c2 = Compressor::new(&model_path).unwrap();
2570 let out2 = c2.compress(data, CoderType::AC).unwrap();
2571 assert_eq!(out1_after_export, out2);
2572
2573 std::fs::remove_file(&model_path).ok();
2574 std::fs::remove_file(model_path.with_extension("json")).ok();
2575 std::fs::remove_file(&opt_path).ok();
2576 }
2577
2578 #[test]
2579 fn test_online_export_reload_missing_full_adam_sidecar_fails() {
2580 let method = "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=41,train=adam,lr=0.0008,stride=1;policy:schedule=0..100:train(scope=all,opt=adam,lr=0.0008,stride=1,bptt=1,clip=0,momentum=0.9)";
2581 let mut c = Compressor::new_from_method(method).unwrap();
2582 let _ = c
2583 .compress(b"rwkv strict optimizer-sidecar requirement", CoderType::AC)
2584 .unwrap();
2585
2586 let model_path = temp_path("rwkv_full_adam_missing_sidecar", "safetensors");
2587 let opt_path = optimizer_sidecar_path(&model_path);
2588 c.export_online(&model_path).unwrap();
2589 std::fs::remove_file(&opt_path).unwrap();
2590
2591 let err = match Compressor::new(&model_path) {
2592 Ok(_) => panic!("expected missing optimizer sidecar to fail"),
2593 Err(err) => err,
2594 };
2595 assert!(format!("{err:#}").contains("missing optimizer sidecar"));
2596
2597 std::fs::remove_file(&model_path).ok();
2598 std::fs::remove_file(model_path.with_extension("json")).ok();
2599 }
2600
2601 #[test]
2602 fn test_clone_preserves_non_head_training_trace() {
2603 let method = "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=43,train=adam,lr=0.0008,stride=1;policy:schedule=0..100:train(scope=all,opt=adam,lr=0.0008,stride=1,bptt=1,clip=0,momentum=0.9)";
2604 let mut c = Compressor::new_from_method(method).unwrap();
2605 c.reset_and_prime();
2606 c.absorb_chain(&[b"clone trace prefix".as_slice()]).unwrap();
2607
2608 let mut cloned = c.clone();
2609 let score = cloned
2610 .cross_entropy_from_current(b"clone trace query")
2611 .unwrap();
2612 assert!(score.is_finite());
2613 }
2614}