infotheory/backends/rwkvzip/
mod.rs

1// rwkvzip - High-performance neural network compressor using RWKV7.
2//
3// This library provides lossless compression by leveraging the RWKV7 language model's
4// predictive capabilities to generate probability distributions, which are then
5// compressed via entropy coding (arithmetic coding or rANS).
6//
7// # Architecture
8//
9// - **Byte-level compression**: Operates directly on raw bytes (vocab_size=256)
10// - **Infinite context**: RWKV7's recurrent architecture maintains state indefinitely
11// - **Portable SIMD optimized**: `wide`-based kernels with ISA-specific codegen
12// - **Correct-by-construction**: Information-theoretically sound implementation
13
14use anyhow::{Context, Result, bail};
15use serde_json::json;
16use std::fs;
17use std::io::{Cursor, Read, Write};
18use std::path::{Path, PathBuf};
19use std::sync::Arc;
20
21use crate::backends::llm_policy::{
22    self, LlmPolicy, OptimizerKind, PolicyAction, PolicyRuntime, split_method_policy_segments,
23};
24/// RWKV7 model core (weights/state/kernels).
25pub mod rwkv7;
26/// Shared entropy coders used by rwkvzip containers.
27pub use crate::coders;
28/// Backward-compatible re-export of generic coder selection enum.
29pub use crate::coders::CoderType;
30
31use crate::coders::{
32    ANS_TOTAL, ArithmeticDecoder, ArithmeticEncoder, BlockedRansDecoder, BlockedRansEncoder,
33    CDF_TOTAL, Cdf, quantize_pdf_to_cdf_with_buffer, quantize_pdf_to_rans_cdf_with_buffer,
34};
35
36/// RWKV7 model configuration type.
37pub use rwkv7::Config;
38/// RWKV7 model type.
39pub use rwkv7::Model;
40/// RWKV7 temporary scratch buffers used during forward passes.
41pub use rwkv7::ScratchBuffers;
42/// RWKV7 recurrent state container.
43pub use rwkv7::State;
44
45// =============================================================================
46// File Format Constants
47// =============================================================================
48
/// File format magic number. Written little-endian by `Header::write`, so a
/// valid rwkvzip file begins with the ASCII bytes `G`, `T`, `P`, `Z`
/// (0x47 0x54 0x50 0x5a). Used to identify valid rwkvzip compressed files.
pub const MAGIC: u32 = 0x5a505447;

/// File format version. Increment on breaking changes to ensure compatibility.
pub const VERSION: u8 = 2;

/// Vocabulary size for byte-level compression.
/// Each byte (0-255) is treated as a separate symbol.
pub const VOCAB_SIZE: usize = 256;
/// Fast default full-parameter TBPTT window used when policy requests `bptt<=1`.
const DEFAULT_FULL_TBPTT_WINDOW: usize = 8;
// Chunk length used when replaying recorded TBPTT steps (consumers are further
// down the file, outside this chunk).
const TBPTT_REPLAY_CHUNK: usize = 32;
/// Derive the optimizer-state sidecar path for a model file by swapping its
/// final extension for `opt.safetensors`
/// (e.g. `model.safetensors` -> `model.opt.safetensors`).
fn optimizer_sidecar_path(model_path: &Path) -> PathBuf {
    let mut sidecar = model_path.to_path_buf();
    sidecar.set_extension("opt.safetensors");
    sidecar
}
/// Valid train-scope tokens accepted by the rwkv policy parser; handed to
/// `llm_policy::parse_policy_segment` for validation of `policy:` segments.
const RWKV_TRAIN_SCOPES: &[&str] = &[
    "embed",
    "pre_norm",
    "attn_norm",
    "ffn_norm",
    "attn",
    "ffn",
    "head",
    "bias",
    "all",
    "none",
];
77
/// `Write` sink that discards all data and only tracks how many bytes were
/// "written". Useful for measuring serialized sizes without buffering.
struct CountingWriter {
    n: u64,
}

impl CountingWriter {
    /// Start counting from zero.
    #[inline]
    fn new() -> Self {
        CountingWriter { n: 0 }
    }

    /// Total number of bytes accepted so far.
    #[inline]
    fn bytes_written(&self) -> u64 {
        self.n
    }
}

impl Write for CountingWriter {
    #[inline]
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let len = buf.len();
        // Saturating add: the counter itself can never be a failure source.
        self.n = self.n.saturating_add(len as u64);
        Ok(len)
    }

    #[inline]
    fn flush(&mut self) -> std::io::Result<()> {
        // Nothing is buffered, so flushing is a no-op.
        Ok(())
    }
}
107
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
/// Online adaptation mode for RWKV output-bias updates.
///
/// Parsed from the `train`/`train_mode` cfg key, which accepts
/// `0`/`none`/`off`, `1`/`sgd`, and `2`/`adam` (case-insensitive).
pub enum OnlineTrainMode {
    /// Disable online updates.
    None,
    /// SGD updates on output bias.
    Sgd,
    /// Adam updates on output bias.
    Adam,
}
118
#[derive(Clone, Debug)]
/// Configuration for online RWKV model instantiation and adaptation.
///
/// Converted into a validated model [`Config`] by
/// [`OnlineConfig::to_rwkv_config`], which clamps each size to its minimum.
pub struct OnlineConfig {
    /// Hidden size (must be multiple of 64 after clamping).
    pub hidden: usize,
    /// Number of recurrent layers.
    pub layers: usize,
    /// Feed-forward intermediate size.
    pub intermediate: usize,
    /// Low-rank dimension for decay projection.
    pub decay_rank: usize,
    /// Low-rank dimension for key-like projection.
    pub a_rank: usize,
    /// Low-rank dimension for value-like projection.
    pub v_rank: usize,
    /// Low-rank dimension for gate projection.
    pub g_rank: usize,
    /// Random seed used for online-random model initialization.
    pub seed: u64,
    /// Online training mode.
    pub train_mode: OnlineTrainMode,
    /// Learning rate for online adaptation.
    pub lr: f32,
    /// Update stride (apply update every `stride` tokens). Values below 1
    /// are treated as 1 when serialized to a canonical method string.
    pub stride: usize,
}
145
146impl Default for OnlineConfig {
147    fn default() -> Self {
148        Self {
149            hidden: 256,
150            layers: 6,
151            intermediate: 1024,
152            decay_rank: 32,
153            a_rank: 32,
154            v_rank: 32,
155            g_rank: 64,
156            seed: 0,
157            train_mode: OnlineTrainMode::None,
158            lr: 0.001,
159            stride: 1,
160        }
161    }
162}
163
164impl OnlineConfig {
165    /// Convert to a validated RWKV model configuration.
166    pub fn to_rwkv_config(&self) -> Result<Config> {
167        let hidden = self.hidden.max(64);
168        if !hidden.is_multiple_of(64) {
169            bail!("rwkv hidden must be a multiple of 64 (got {hidden})");
170        }
171        let num_heads = hidden / 64;
172        let cfg = Config {
173            vocab_size: 256,
174            hidden_size: hidden,
175            num_layers: self.layers.max(1),
176            num_heads,
177            head_dim: 64,
178            intermediate_size: self.intermediate.max(1),
179            layer_norm_eps: 1e-5,
180            group_norm_eps: 64e-5,
181            decay_low_rank: self.decay_rank.max(1),
182            a_low_rank: self.a_rank.max(1),
183            v_low_rank: self.v_rank.max(1),
184            g_low_rank: self.g_rank.max(1),
185        };
186        cfg.validate()?;
187        Ok(cfg)
188    }
189}
190
#[derive(Clone, Debug)]
/// Parsed RWKV method specification.
///
/// Produced by [`parse_method_spec`] from a user method string such as
/// `file:<path>[;policy:...]`, `cfg:<k=v,...>[;policy:...]`, a positional
/// cfg CSV, or an existing model path.
pub enum MethodSpec {
    /// Load a model from disk.
    File {
        /// Path to `.safetensors` model weights.
        path: PathBuf,
        /// Optional runtime training/inference policy.
        policy: Option<LlmPolicy>,
    },
    /// Build an online/random model from configuration.
    Online {
        /// Online model/training configuration.
        cfg: OnlineConfig,
        /// Optional runtime training/inference policy.
        policy: Option<LlmPolicy>,
    },
}
209
#[derive(Clone)]
/// Mutable bookkeeping for online (in-stream) model adaptation.
struct OnlineRuntime {
    // Online model/training configuration this runtime was built from.
    cfg: OnlineConfig,
    // Canonical method string describing this runtime.
    canonical_method: String,
    // Optional training/inference policy parsed from the method string.
    policy: Option<LlmPolicy>,
    // Compiled policy schedule; built lazily by `prepare_policy_stream`.
    policy_runtime: Option<PolicyRuntime>,
    // True when the policy trains non-head scopes and therefore needs a
    // full forward trace (see `policy_needs_full_trace`).
    needs_full_trace: bool,
    // Total symbols in the current stream, when known up front.
    policy_stream_total: Option<u64>,
    // Per-stream train-step counter; reset by `prepare_policy_stream`.
    policy_train_steps: u64,
    // Running count of tokens processed.
    tokens_processed: u64,
    // Additive bias applied to output logits (one entry per vocab symbol).
    out_bias: Vec<f32>,
    // Adam first/second moments for `out_bias`; allocated only when Adam
    // updates are possible (cfg or policy).
    adam_m: Option<Vec<f32>>,
    adam_v: Option<Vec<f32>>,
    // Optimizer state for full-parameter training.
    full_adam: Option<rwkv7::FullAdamState>,
    // Adam first/second moments for the LM head (vocab_size * hidden_size).
    lm_head_adam_m: Option<Vec<f32>>,
    lm_head_adam_v: Option<Vec<f32>>,
    // Adam timestep, incremented once per `apply_online_lm_head_update`.
    adam_t: usize,
    // TBPTT trace buffer; present iff `needs_full_trace`.
    full_tbptt: Option<FullTbpttRuntime>,
}
229
#[derive(Clone, Copy, Debug)]
/// Hyper-parameters of a full-parameter training segment. Stored in
/// `FullTbpttRuntime::settings` and compared via [`FullTrainSettings::matches`].
struct FullTrainSettings {
    // Optimizer used for the segment.
    optimizer: OptimizerKind,
    // Learning rate (compared bit-for-bit in `matches`).
    lr: f32,
    // Which parameter groups are trained.
    scope: rwkv7::TrainScopeMask,
    // Truncated-BPTT window length.
    bptt: usize,
    // Gradient clip threshold (compared bit-for-bit in `matches`).
    clip: f32,
}
238
239impl FullTrainSettings {
240    fn matches(
241        self,
242        optimizer: OptimizerKind,
243        lr: f32,
244        scope: rwkv7::TrainScopeMask,
245        bptt: usize,
246        clip: f32,
247    ) -> bool {
248        self.optimizer == optimizer
249            && self.lr.to_bits() == lr.to_bits()
250            && self.scope == scope
251            && self.bptt == bptt
252            && self.clip.to_bits() == clip.to_bits()
253    }
254}
255
#[derive(Clone)]
/// Recorded forward-pass trace used for truncated BPTT over full parameters.
// NOTE(review): the consuming/replay code lives further down the file; the
// per-field notes below are inferred from names and should be confirmed there.
struct FullTbpttRuntime {
    // Last input token not yet folded into the trace, if any.
    pending_input_token: Option<u32>,
    // State snapshot captured before consuming `pending_input_token`.
    pending_input_pre_state: Option<State>,
    // State snapshot at the start of the current TBPTT segment.
    segment_start_state: Option<State>,
    // Recorded (token, symbol) pairs for the current segment.
    steps: Vec<(u32, u8)>,
    // Settings of the segment being accumulated; `None` when idle.
    settings: Option<FullTrainSettings>,
}
264
#[derive(Clone)]
/// Snapshot of mutable runtime state used for reversible scoring.
///
/// Captures everything a forward pass can mutate (recurrent state, scratch
/// buffers, PDF buffer, online-training bookkeeping) plus the model handle,
/// so scoring can be rolled back by restoring the snapshot.
pub struct RuntimeSnapshot {
    // Shared model handle (cheap to clone via `Arc`).
    model: Arc<Model>,
    // Scratch buffers for model forward passes.
    scratch: ScratchBuffers,
    // Recurrent hidden state.
    state: State,
    // Pre-allocated probability buffer (one entry per vocab symbol).
    pdf_buffer: Vec<f64>,
    // Online-adaptation state, when an online method is in use.
    online: Option<OnlineRuntime>,
}
274
275impl OnlineRuntime {
276    fn new(
277        cfg: OnlineConfig,
278        canonical_method: String,
279        policy: Option<LlmPolicy>,
280        vocab_size: usize,
281        hidden_size: usize,
282    ) -> Self {
283        let mut use_adam = matches!(cfg.train_mode, OnlineTrainMode::Adam);
284        if let Some(pol) = &policy {
285            use_adam = policy_uses_adam(pol) || use_adam;
286        }
287        let needs_full_trace = policy
288            .as_ref()
289            .map(policy_needs_full_trace)
290            .unwrap_or(false);
291        Self {
292            canonical_method,
293            cfg,
294            policy,
295            policy_runtime: None,
296            needs_full_trace,
297            policy_stream_total: None,
298            policy_train_steps: 0,
299            tokens_processed: 0,
300            out_bias: vec![0.0; vocab_size],
301            adam_m: use_adam.then(|| vec![0.0; vocab_size]),
302            adam_v: use_adam.then(|| vec![0.0; vocab_size]),
303            full_adam: None,
304            lm_head_adam_m: use_adam.then(|| vec![0.0; vocab_size * hidden_size]),
305            lm_head_adam_v: use_adam.then(|| vec![0.0; vocab_size * hidden_size]),
306            adam_t: 0,
307            full_tbptt: needs_full_trace.then(|| FullTbpttRuntime {
308                pending_input_token: None,
309                pending_input_pre_state: None,
310                segment_start_state: None,
311                steps: Vec::new(),
312                settings: None,
313            }),
314        }
315    }
316
317    fn prepare_policy_stream(&mut self, total_symbols: Option<u64>) -> Result<()> {
318        self.policy_stream_total = total_symbols;
319        self.policy_train_steps = 0;
320        if let Some(tbptt) = self.full_tbptt.as_mut() {
321            // Preserve the current predictive edge so the first symbol of the
322            // new stream still trains against the already-primed distribution.
323            tbptt.segment_start_state = None;
324            tbptt.steps.clear();
325            tbptt.settings = None;
326        }
327        self.policy_runtime = match &self.policy {
328            Some(p) => Some(PolicyRuntime::new(p.compile(total_symbols)?)),
329            None => None,
330        };
331        Ok(())
332    }
333
334    #[inline]
335    fn next_policy_action(&mut self) -> Result<Option<PolicyAction>> {
336        if self.policy.is_none() {
337            return Ok(None);
338        }
339        if self.policy_runtime.is_none() {
340            self.prepare_policy_stream(None)?;
341        }
342        Ok(self.policy_runtime.as_mut().map(PolicyRuntime::next_action))
343    }
344}
345
#[allow(clippy::needless_range_loop, clippy::too_many_arguments)]
/// Apply one online cross-entropy step to the LM head weights and/or the
/// additive output bias, given the hidden vector and the predicted PDF.
///
/// The per-logit gradient is `target_i - p_i` (one-hot target at `symbol`),
/// optionally clipped to `[-clip, clip]` when `clip > 0`. `train_head` /
/// `train_bias` select which parameter sets are updated; `optimizer` chooses
/// plain SGD or Adam (moment buffers are lazily allocated).
fn apply_online_lm_head_update(
    model: &mut Model,
    online: &mut OnlineRuntime,
    hidden: &[f32],
    symbol: u8,
    pdf: &[f64],
    lr: f32,
    optimizer: OptimizerKind,
    train_head: bool,
    train_bias: bool,
    clip: f32,
) {
    let h = hidden.len();
    if h == 0 {
        return;
    }

    let head = model.lm_head_weights_mut();
    // Only update indices valid for the bias, the PDF, and the head matrix
    // simultaneously, guarding against any shape mismatch.
    let vocab_rows = head.len() / h;
    let n = online.out_bias.len().min(pdf.len()).min(vocab_rows);

    match optimizer {
        OptimizerKind::Sgd => {
            for (i, p_raw) in pdf.iter().enumerate().take(n) {
                // Clamp away zero/over-unit probabilities before the f32 cast.
                let p = (*p_raw).clamp(1e-12, 1.0) as f32;
                let target = if i == symbol as usize { 1.0 } else { 0.0 };
                let mut grad = target - p;
                if clip > 0.0 {
                    grad = grad.clamp(-clip, clip);
                }
                if train_bias {
                    online.out_bias[i] += lr * grad;
                }

                if train_head {
                    // Outer-product update: row i moves along `hidden`.
                    let row_off = i * h;
                    for j in 0..h {
                        head[row_off + j] += lr * grad * hidden[j];
                    }
                }
            }
        }
        OptimizerKind::Adam => {
            // One shared Adam timestep for both bias and head moments.
            online.adam_t = online.adam_t.saturating_add(1);
            let t = online.adam_t as i32;
            let b1 = 0.9f32;
            let b2 = 0.999f32;
            let eps = 1e-8f32;
            let bias_corr1 = 1.0 - b1.powi(t);
            let bias_corr2 = 1.0 - b2.powi(t);
            // Lazily allocate moment buffers if they were not created up front.
            if online.adam_m.is_none() || online.adam_v.is_none() {
                online.adam_m = Some(vec![0.0; online.out_bias.len()]);
                online.adam_v = Some(vec![0.0; online.out_bias.len()]);
            }
            if online.lm_head_adam_m.is_none() || online.lm_head_adam_v.is_none() {
                online.lm_head_adam_m = Some(vec![0.0; vocab_rows * h]);
                online.lm_head_adam_v = Some(vec![0.0; vocab_rows * h]);
            }
            let bm = online.adam_m.as_mut().expect("adam_m initialized");
            let bv = online.adam_v.as_mut().expect("adam_v initialized");
            let hm = online
                .lm_head_adam_m
                .as_mut()
                .expect("lm_head_adam_m initialized");
            let hv = online
                .lm_head_adam_v
                .as_mut()
                .expect("lm_head_adam_v initialized");
            for i in 0..n {
                let p = pdf[i].clamp(1e-12, 1.0) as f32;
                let target = if i == symbol as usize { 1.0 } else { 0.0 };
                let mut grad = target - p;
                if clip > 0.0 {
                    grad = grad.clamp(-clip, clip);
                }

                if train_bias {
                    // Standard Adam with bias-corrected moment estimates.
                    bm[i] = b1 * bm[i] + (1.0 - b1) * grad;
                    bv[i] = b2 * bv[i] + (1.0 - b2) * grad * grad;
                    let m_hat = bm[i] / bias_corr1;
                    let v_hat = bv[i] / bias_corr2;
                    online.out_bias[i] += lr * m_hat / (v_hat.sqrt() + eps);
                }

                if train_head {
                    let row_off = i * h;
                    for j in 0..h {
                        let idx = row_off + j;
                        // Per-weight gradient of the outer product.
                        let g = grad * hidden[j];
                        hm[idx] = b1 * hm[idx] + (1.0 - b1) * g;
                        hv[idx] = b2 * hv[idx] + (1.0 - b2) * g * g;
                        let m_hat_w = hm[idx] / bias_corr1;
                        let v_hat_w = hv[idx] / bias_corr2;
                        head[idx] += lr * m_hat_w / (v_hat_w.sqrt() + eps);
                    }
                }
            }
        }
    }
}
447
448fn policy_uses_adam(policy: &LlmPolicy) -> bool {
449    use llm_policy::ScheduleRule;
450    for rule in &policy.schedule {
451        match rule {
452            ScheduleRule::Interval(interval) => {
453                if let PolicyAction::Train(train) = &interval.action
454                    && matches!(train.optimizer, OptimizerKind::Adam)
455                {
456                    return true;
457                }
458            }
459            ScheduleRule::Repeat(repeat) => {
460                for seg in &repeat.pattern {
461                    if let PolicyAction::Train(train) = &seg.action
462                        && matches!(train.optimizer, OptimizerKind::Adam)
463                    {
464                        return true;
465                    }
466                }
467            }
468        }
469    }
470    false
471}
472
473fn scope_needs_full_trace(scope: &llm_policy::TrainScopeSet) -> bool {
474    scope.all
475        || scope.contains("embed")
476        || scope.contains("pre_norm")
477        || scope.contains("attn_norm")
478        || scope.contains("ffn_norm")
479        || scope.contains("attn")
480        || scope.contains("ffn")
481}
482
483fn policy_needs_full_trace(policy: &LlmPolicy) -> bool {
484    use llm_policy::ScheduleRule;
485    for rule in &policy.schedule {
486        match rule {
487            ScheduleRule::Interval(interval) => {
488                if let PolicyAction::Train(train) = &interval.action
489                    && scope_needs_full_trace(&train.scope)
490                {
491                    return true;
492                }
493            }
494            ScheduleRule::Repeat(repeat) => {
495                for seg in &repeat.pattern {
496                    if let PolicyAction::Train(train) = &seg.action
497                        && scope_needs_full_trace(&train.scope)
498                    {
499                        return true;
500                    }
501                }
502            }
503        }
504    }
505    false
506}
507
508fn scope_from_train_action(train: &llm_policy::TrainAction) -> rwkv7::TrainScopeMask {
509    if train.scope.all {
510        return rwkv7::TrainScopeMask::all();
511    }
512    rwkv7::TrainScopeMask {
513        embed: train.scope.contains("embed"),
514        pre_norm: train.scope.contains("pre_norm"),
515        attn_norm: train.scope.contains("attn_norm"),
516        ffn_norm: train.scope.contains("ffn_norm"),
517        attn: train.scope.contains("attn"),
518        ffn: train.scope.contains("ffn"),
519        head: train.scope.contains("head"),
520        bias: train.scope.contains("bias"),
521    }
522}
523
524fn cfg_to_method_string(cfg: &OnlineConfig) -> String {
525    let train = match cfg.train_mode {
526        OnlineTrainMode::None => "none",
527        OnlineTrainMode::Sgd => "sgd",
528        OnlineTrainMode::Adam => "adam",
529    };
530    format!(
531        "cfg:hidden={},layers={},intermediate={},decay_rank={},a_rank={},v_rank={},g_rank={},seed={},train={},lr={},stride={}",
532        cfg.hidden,
533        cfg.layers,
534        cfg.intermediate,
535        cfg.decay_rank,
536        cfg.a_rank,
537        cfg.v_rank,
538        cfg.g_rank,
539        cfg.seed,
540        train,
541        cfg.lr,
542        cfg.stride.max(1),
543    )
544}
545
/// Numerically stable softmax over `logits` (plus optional additive `bias`)
/// into `pdf_out`, flooring every probability at `1e-12` and re-normalizing
/// so the result still sums to 1.
///
/// Empty input is a no-op; whenever a normalizer comes out non-finite or
/// non-positive, `1/len` is substituted so no NaN/Inf scale factor is used.
fn softmax_pdf_floor_with_bias(logits: &[f32], bias: Option<&[f32]>, pdf_out: &mut [f64]) {
    debug_assert_eq!(logits.len(), pdf_out.len());
    if let Some(b) = bias {
        debug_assert_eq!(b.len(), logits.len());
    }
    let n = logits.len();
    if n == 0 {
        return;
    }

    // Effective logit at index i: model logit plus optional additive bias.
    let z = |i: usize| match bias {
        Some(b) => logits[i] + b[i],
        None => logits[i],
    };

    // Max-shift for numerical stability before exponentiation.
    let mut max_logit = f32::NEG_INFINITY;
    for i in 0..n {
        let zi = z(i);
        if zi > max_logit {
            max_logit = zi;
        }
    }

    let mut sum = 0.0f64;
    for i in 0..n {
        let p = ((z(i) - max_logit) as f64).exp();
        pdf_out[i] = p;
        sum += p;
    }

    let uniform = 1.0 / (n as f64);
    let inv_sum = if sum.is_finite() && sum > 0.0 {
        1.0 / sum
    } else {
        uniform
    };

    // Floor, then renormalize so the floored mass is redistributed.
    let floor = 1e-12f64;
    let mut norm = 0.0f64;
    for p in pdf_out.iter_mut() {
        *p = (*p * inv_sum).max(floor);
        norm += *p;
    }
    let inv_norm = if norm.is_finite() && norm > 0.0 {
        1.0 / norm
    } else {
        uniform
    };
    for p in pdf_out.iter_mut() {
        *p *= inv_norm;
    }
}
607
/// Parse `v` as a `u64`, attaching the offending cfg key to the error.
fn parse_u64(v: &str, key: &str) -> Result<u64> {
    v.parse::<u64>()
        .with_context(|| format!("invalid integer value for '{key}': {v}"))
}
612
/// Parse `v` as a `usize`, attaching the offending cfg key to the error.
fn parse_usize(v: &str, key: &str) -> Result<usize> {
    v.parse::<usize>()
        .with_context(|| format!("invalid integer value for '{key}': {v}"))
}
617
/// Parse `v` as an `f32`, attaching the offending cfg key to the error.
fn parse_f32(v: &str, key: &str) -> Result<f32> {
    v.parse::<f32>()
        .with_context(|| format!("invalid float value for '{key}': {v}"))
}
622
623fn parse_train_mode_token(v: &str) -> Result<OnlineTrainMode> {
624    let code = v.trim().to_ascii_lowercase();
625    match code.as_str() {
626        "0" | "none" | "off" => Ok(OnlineTrainMode::None),
627        "1" | "sgd" => Ok(OnlineTrainMode::Sgd),
628        "2" | "adam" => Ok(OnlineTrainMode::Adam),
629        other => bail!("unknown train mode '{other}'"),
630    }
631}
632
633fn parse_cfg_positional(csv: &str) -> Result<OnlineConfig> {
634    let vals: Vec<&str> = csv
635        .split(',')
636        .map(|s| s.trim())
637        .filter(|s| !s.is_empty())
638        .collect();
639    if vals.len() != 6 && vals.len() != 7 {
640        bail!(
641            "positional cfg format expects 6 or 7 values: hidden,intermediate,layers,train,seed,lr[,stride]"
642        );
643    }
644
645    let cfg = OnlineConfig {
646        hidden: parse_usize(vals[0], "hidden")?,
647        intermediate: parse_usize(vals[1], "intermediate")?,
648        layers: parse_usize(vals[2], "layers")?,
649        train_mode: parse_train_mode_token(vals[3])?,
650        seed: parse_u64(vals[4], "seed")?,
651        lr: parse_f32(vals[5], "lr")?,
652        stride: if vals.len() == 7 {
653            parse_usize(vals[6], "stride")?
654        } else {
655            1
656        },
657        ..OnlineConfig::default()
658    };
659    Ok(cfg)
660}
661
/// Parse a method string into a concrete RWKV method specification.
///
/// Supported formats, tried in this order (order matters — a bare path is
/// only treated as a model file if it exists on disk):
/// - `file:/path/to/model.safetensors`
/// - `file:/path/to/model.safetensors;policy:...`
/// - `cfg:key=value,...[;policy:...]`
/// - positional `cfg` CSV
/// - existing model path
pub fn parse_method_spec(method: &str) -> Result<MethodSpec> {
    // Split off an optional trailing `policy:` segment and parse it first so
    // every branch below can attach it.
    let (base, policy_segment) = split_method_policy_segments(method)?;
    let parse_policy = |s: &str| llm_policy::parse_policy_segment(s, RWKV_TRAIN_SCOPES);
    let policy = policy_segment
        .as_deref()
        .map(parse_policy)
        .transpose()
        .context("failed to parse rwkv policy segment")?;

    // 1) Explicit `file:` prefix.
    if let Some(path) = base.strip_prefix("file:") {
        let p = PathBuf::from(path.trim());
        if p.as_os_str().is_empty() {
            bail!("empty file path in rwkv method");
        }
        // `load_from` would name a second weights source, conflicting with
        // the explicit file path.
        if policy.as_ref().and_then(|p| p.load_from.as_ref()).is_some() {
            bail!("rwkv method cannot use policy load_from together with file:<path>");
        }
        return Ok(MethodSpec::File { path: p, policy });
    }

    // 2) Explicit `cfg:` prefix — keyword form if any `=` is present,
    //    positional CSV otherwise.
    if let Some(cfg_s) = base.strip_prefix("cfg:") {
        if !cfg_s.contains('=') {
            return Ok(MethodSpec::Online {
                cfg: parse_cfg_positional(cfg_s)?,
                policy,
            });
        }
        let mut cfg = OnlineConfig::default();
        for pair in cfg_s.split(',') {
            let pair = pair.trim();
            if pair.is_empty() {
                continue;
            }
            let (k, v) = pair
                .split_once('=')
                .with_context(|| format!("invalid cfg key/value pair '{pair}'"))?;
            let key = k.trim().to_ascii_lowercase();
            let val = v.trim();
            match key.as_str() {
                "hidden" => cfg.hidden = parse_usize(val, "hidden")?,
                "layers" => cfg.layers = parse_usize(val, "layers")?,
                "intermediate" => cfg.intermediate = parse_usize(val, "intermediate")?,
                "decay_rank" => cfg.decay_rank = parse_usize(val, "decay_rank")?,
                "a_rank" => cfg.a_rank = parse_usize(val, "a_rank")?,
                "v_rank" => cfg.v_rank = parse_usize(val, "v_rank")?,
                "g_rank" => cfg.g_rank = parse_usize(val, "g_rank")?,
                "seed" => cfg.seed = parse_u64(val, "seed")?,
                "lr" => cfg.lr = parse_f32(val, "lr")?,
                "stride" => cfg.stride = parse_usize(val, "stride")?,
                "train" | "train_mode" => cfg.train_mode = parse_train_mode_token(val)?,
                other => bail!("unknown rwkv cfg key '{other}'"),
            }
        }
        return Ok(MethodSpec::Online { cfg, policy });
    }

    // 3) Bare path, accepted only when it actually exists on disk.
    let plain = PathBuf::from(base.trim());
    if plain.exists() {
        if policy.as_ref().and_then(|p| p.load_from.as_ref()).is_some() {
            bail!("rwkv method cannot use policy load_from together with file path");
        }
        return Ok(MethodSpec::File {
            path: plain,
            policy,
        });
    }

    // 4) Last resort: an un-prefixed positional CSV (contains commas).
    if base.contains(',') {
        return Ok(MethodSpec::Online {
            cfg: parse_cfg_positional(&base)?,
            policy,
        });
    }

    bail!(
        "rwkv method must be 'file:<path>', 'cfg:<k=v,...>', positional cfg CSV, or an existing model path"
    );
}
748
749// =============================================================================
750// File Header
751// =============================================================================
752
/// Header structure for compressed data files.
///
/// Layout (18 bytes total):
/// - magic: 4 bytes (little-endian u32)
/// - version: 1 byte
/// - coder: 1 byte (0=AC, 1=rANS)
/// - original_len: 8 bytes (little-endian u64)
/// - crc32: 4 bytes (little-endian u32)
///
/// Serialized and deserialized by [`Header::write`] / [`Header::read`].
#[derive(Debug, Clone)]
pub struct Header {
    /// Magic number for format identification (must be [`MAGIC`]).
    pub magic: u32,
    /// Format version for compatibility checking; readers accept any value
    /// up to [`VERSION`].
    pub version: u8,
    /// Coder type used (0=AC, 1=rANS).
    pub coder: u8,
    /// Original uncompressed data length in bytes.
    pub original_len: u64,
    /// CRC32 checksum of original data for integrity verification.
    pub crc32: u32,
}
774
775impl Header {
776    /// Total header size in bytes.
777    pub const SIZE: usize = 4 + 1 + 1 + 8 + 4; // 18 bytes
778
779    /// Create a new header for compressed data.
780    pub fn new(coder: CoderType, original_len: u64, crc32: u32) -> Self {
781        Self {
782            magic: MAGIC,
783            version: VERSION,
784            coder: match coder {
785                CoderType::AC => 0,
786                CoderType::RANS => 1,
787            },
788            original_len,
789            crc32,
790        }
791    }
792
793    /// Serialize header to a writer (little-endian format).
794    pub fn write<W: Write>(&self, w: &mut W) -> Result<()> {
795        w.write_all(&self.magic.to_le_bytes())?;
796        w.write_all(&[self.version])?;
797        w.write_all(&[self.coder])?;
798        w.write_all(&self.original_len.to_le_bytes())?;
799        w.write_all(&self.crc32.to_le_bytes())?;
800        Ok(())
801    }
802
803    /// Deserialize header from a reader (little-endian format).
804    pub fn read<R: Read>(r: &mut R) -> Result<Self> {
805        let mut buf4 = [0u8; 4];
806        let mut buf8 = [0u8; 8];
807        let mut buf1 = [0u8; 1];
808
809        r.read_exact(&mut buf4)?;
810        let magic = u32::from_le_bytes(buf4);
811        if magic != MAGIC {
812            bail!(
813                "Invalid magic number: expected 0x{:08X}, got 0x{:08X}",
814                MAGIC,
815                magic
816            );
817        }
818
819        r.read_exact(&mut buf1)?;
820        let version = buf1[0];
821        if version > VERSION {
822            bail!(
823                "Unsupported version: {} (max supported: {})",
824                version,
825                VERSION
826            );
827        }
828
829        r.read_exact(&mut buf1)?;
830        let coder = buf1[0];
831
832        r.read_exact(&mut buf8)?;
833        let original_len = u64::from_le_bytes(buf8);
834
835        r.read_exact(&mut buf4)?;
836        let crc32 = u32::from_le_bytes(buf4);
837
838        Ok(Self {
839            magic,
840            version,
841            coder,
842            original_len,
843            crc32,
844        })
845    }
846
847    /// Get the coder type from the header byte.
848    pub fn coder_type(&self) -> CoderType {
849        match self.coder {
850            0 => CoderType::AC,
851            _ => CoderType::RANS,
852        }
853    }
854}
855
856// =============================================================================
857// CRC32 Checksum
858// =============================================================================
859
/// Compute CRC32 checksum for data integrity verification.
///
/// Thin wrapper delegating to the shared implementation in
/// [`crate::coders::crc32`] (backed by the crc32fast crate for
/// hardware-accelerated computation).
pub fn crc32(data: &[u8]) -> u32 {
    crate::coders::crc32(data)
}
866
867// =============================================================================
868// Compressor
869// =============================================================================
870
/// Main compressor/decompressor that combines RWKV7 inference with entropy coding.
///
/// The compressor maintains internal state and pre-allocated buffers to minimize
/// allocations during the compression/decompression hot path.
pub struct Compressor {
    /// RWKV7 model for generating probability distributions.
    pub model: Arc<Model>,
    /// Model state (recurrent hidden states).
    pub state: State,
    /// Scratch buffers for model forward passes.
    pub scratch: ScratchBuffers,
    /// Pre-allocated PDF buffer (eliminates allocations in compression loop).
    pub pdf_buffer: Vec<f64>,
    /// Reusable AC CDF buffer (vocab_size + 1 entries).
    pub cdf_buffer_ac: Vec<u32>,
    /// Scratch frequencies for AC quantization.
    pub ac_freq_buffer: Vec<i64>,
    /// Reusable rANS CDF buffer (vocab_size + 1 entries).
    pub cdf_buffer_rans: Vec<u32>,
    /// Scratch frequencies for rANS quantization.
    pub rans_freq_buffer: Vec<i64>,
    // Online-adaptation runtime; `None` for plain file-based models.
    online: Option<OnlineRuntime>,
    // Where the weights were loaded from; set by `Compressor::new` and
    // presumably used to locate sidecar files (see `optimizer_sidecar_path`)
    // — confirm against `maybe_load_sidecar` further down the file.
    source_model_path: Option<PathBuf>,
}
895
896impl Clone for Compressor {
897    fn clone(&self) -> Self {
898        let mut cloned = Self::new_from_model(self.model.clone());
899        cloned.state = self.state.clone();
900        cloned.pdf_buffer.clone_from(&self.pdf_buffer);
901        cloned.cdf_buffer_ac.clone_from(&self.cdf_buffer_ac);
902        cloned.ac_freq_buffer.clone_from(&self.ac_freq_buffer);
903        cloned.cdf_buffer_rans.clone_from(&self.cdf_buffer_rans);
904        cloned.rans_freq_buffer.clone_from(&self.rans_freq_buffer);
905        cloned.scratch = self.scratch.clone();
906        cloned.online = self.online.clone();
907        cloned.source_model_path = self.source_model_path.clone();
908        cloned
909    }
910}
911
912impl Compressor {
913    /// Create a new compressor with the given model.
914    ///
915    /// # Arguments
916    /// * `model_path` - Path to RWKV7 model weights (.safetensors format)
917    ///
918    /// # Returns
919    /// A new Compressor ready for compression/decompression operations.
920    pub fn new<P: AsRef<Path>>(model_path: P) -> Result<Self> {
921        let model_path = model_path.as_ref();
922        let model = Arc::new(Model::load(model_path)?);
923        let mut c = Self::new_from_model(model);
924        c.source_model_path = Some(model_path.to_path_buf());
925        c.maybe_load_sidecar()?;
926        Ok(c)
927    }
928
929    /// Load a model from disk and wrap it in `Arc`.
930    pub fn load_model<P: AsRef<Path>>(model_path: P) -> Result<Arc<Model>> {
931        Ok(Arc::new(Model::load(model_path)?))
932    }
933
934    /// Create a compressor from a preloaded model.
935    pub fn new_from_model(model: Arc<Model>) -> Self {
936        let state = model.new_state();
937        let vocab_size = model.config().vocab_size;
938        let scratch = ScratchBuffers::new(model.config());
939        Self {
940            model,
941            state,
942            scratch,
943            pdf_buffer: vec![0.0f64; vocab_size],
944            cdf_buffer_ac: vec![0u32; vocab_size + 1],
945            ac_freq_buffer: vec![0i64; vocab_size],
946            cdf_buffer_rans: vec![0u32; vocab_size + 1],
947            rans_freq_buffer: vec![0i64; vocab_size],
948            online: None,
949            source_model_path: None,
950        }
951    }
952
    /// Create a compressor from a user method string.
    ///
    /// Two method shapes are accepted (see `parse_method_spec`):
    /// - `File`: load weights from disk, optionally attaching a training policy.
    /// - `Online`: build a model from an online config (random init, or
    ///   weights loaded via the policy's `load_from` with a strict shape check).
    pub fn new_from_method(method: &str) -> Result<Self> {
        match parse_method_spec(method)? {
            MethodSpec::File { path, policy } => {
                let mut c = Self::new(&path)?;
                if let Some(policy) = policy {
                    // Canonicalize the method string so exported sidecars are stable.
                    let canonical_method =
                        format!("file:{};policy:{}", path.display(), policy.canonical());
                    let hidden = c.model.config().hidden_size;
                    // Reuse a runtime restored from the sidecar (by `Self::new`)
                    // when present; otherwise start a fresh default runtime.
                    let mut online = c.online.take().unwrap_or_else(|| {
                        OnlineRuntime::new(
                            OnlineConfig::default(),
                            canonical_method.clone(),
                            Some(policy.clone()),
                            VOCAB_SIZE,
                            hidden,
                        )
                    });
                    // The requested policy overrides whatever the sidecar carried.
                    online.canonical_method = canonical_method;
                    online.policy = Some(policy);
                    online.needs_full_trace = online
                        .policy
                        .as_ref()
                        .map(policy_needs_full_trace)
                        .unwrap_or(false);
                    c.online = Some(online);
                    // Trace capture is required for full-parameter TBPTT replay.
                    c.scratch.set_capture_train_trace(
                        c.online.as_ref().is_some_and(|o| o.needs_full_trace),
                    );
                }
                Ok(c)
            }
            MethodSpec::Online { cfg, policy } => {
                let rwcfg = cfg.to_rwkv_config()?;
                // `load_from` initializes from existing weights, but only when
                // every architecture dimension matches the requested config.
                let model = if let Some(load_from) =
                    policy.as_ref().and_then(|p| p.load_from.as_ref())
                {
                    let loaded = Arc::new(Model::load(load_from)?);
                    let loaded_cfg = loaded.config();
                    let shape_ok = loaded_cfg.vocab_size == rwcfg.vocab_size
                        && loaded_cfg.hidden_size == rwcfg.hidden_size
                        && loaded_cfg.num_layers == rwcfg.num_layers
                        && loaded_cfg.num_heads == rwcfg.num_heads
                        && loaded_cfg.head_dim == rwcfg.head_dim
                        && loaded_cfg.intermediate_size == rwcfg.intermediate_size
                        && loaded_cfg.decay_low_rank == rwcfg.decay_low_rank
                        && loaded_cfg.a_low_rank == rwcfg.a_low_rank
                        && loaded_cfg.v_low_rank == rwcfg.v_low_rank
                        && loaded_cfg.g_low_rank == rwcfg.g_low_rank;
                    if !shape_ok {
                        bail!(
                            "rwkv policy load_from shape mismatch with cfg (strict match required)"
                        );
                    }
                    loaded
                } else {
                    Arc::new(Model::new_random(rwcfg, cfg.seed)?)
                };
                let mut c = Self::new_from_model(model);
                let mut canonical_method = cfg_to_method_string(&cfg);
                if let Some(policy) = policy.as_ref() {
                    canonical_method.push_str(";policy:");
                    canonical_method.push_str(&policy.canonical());
                }
                c.online = Some(OnlineRuntime::new(
                    cfg,
                    canonical_method,
                    policy,
                    VOCAB_SIZE,
                    c.model.config().hidden_size,
                ));
                c.scratch
                    .set_capture_train_trace(c.online.as_ref().is_some_and(|o| o.needs_full_trace));
                Ok(c)
            }
        }
    }
1030
    /// Reset the model state to initial values.
    ///
    /// Call this between independent compression/decompression operations
    /// to ensure a clean state.
    pub fn reset(&mut self) {
        self.state.reset();
        // Also drop any half-built TBPTT segment tied to the old state; its
        // recorded pre-states would be inconsistent with the reset state.
        self.clear_online_training_buffers();
    }
1039
1040    fn prepare_policy_stream(&mut self, total_symbols: Option<u64>) -> Result<()> {
1041        if let Some(online) = self.online.as_mut() {
1042            online.prepare_policy_stream(total_symbols)?;
1043        }
1044        Ok(())
1045    }
1046
1047    #[inline]
1048    fn effective_full_bptt(scope: rwkv7::TrainScopeMask, bptt: usize) -> usize {
1049        if scope.trains_non_head_params() && bptt <= 1 {
1050            DEFAULT_FULL_TBPTT_WINDOW
1051        } else {
1052            bptt.max(1)
1053        }
1054    }
1055
1056    fn clear_online_training_buffers(&mut self) {
1057        if let Some(online) = self.online.as_mut()
1058            && let Some(tbptt) = online.full_tbptt.as_mut()
1059        {
1060            tbptt.pending_input_token = None;
1061            tbptt.pending_input_pre_state = None;
1062            tbptt.segment_start_state = None;
1063            tbptt.steps.clear();
1064            tbptt.settings = None;
1065        }
1066    }
1067
1068    fn forward_with_online_record(&mut self, token: u32) {
1069        if let Some(online) = self.online.as_mut()
1070            && let Some(tbptt) = online.full_tbptt.as_mut()
1071        {
1072            tbptt.pending_input_token = Some(token);
1073            tbptt.pending_input_pre_state = Some(self.state.clone());
1074        }
1075        let _ = self
1076            .model
1077            .forward(&mut self.scratch, token, &mut self.state);
1078    }
1079
    /// Replay and train on the buffered TBPTT segment, then refresh the PDF.
    ///
    /// Extracts the queued (token, target) steps plus the saved segment-start
    /// state, runs `online_train_segment_tbptt` over them, and leaves
    /// `self.pdf_buffer` holding the post-training predictive distribution.
    /// No-op when there is no online runtime or the step buffer is empty.
    fn flush_full_tbptt_segment(&mut self) -> Result<()> {
        // Phase 1: pull everything needed out of `self.online` in a short
        // borrow so the training call below can borrow `self` fields freely.
        let extracted = {
            match self.online.as_mut() {
                Some(online) => match online.full_tbptt.as_mut() {
                    Some(tbptt) if !tbptt.steps.is_empty() => {
                        let settings = tbptt.settings.ok_or_else(|| {
                            anyhow::anyhow!("rwkv full tbptt settings are missing")
                        })?;
                        let start_state = tbptt.segment_start_state.clone().ok_or_else(|| {
                            anyhow::anyhow!("rwkv full tbptt segment start is missing")
                        })?;
                        let steps = tbptt.steps.clone();
                        tbptt.steps.clear();
                        tbptt.segment_start_state = None;
                        tbptt.settings = None;
                        // Full Adam state is allocated lazily, on the first
                        // Adam update whose scope covers non-head parameters.
                        let need_full_adam = matches!(settings.optimizer, OptimizerKind::Adam)
                            && settings.scope.trains_non_head_params()
                            && online.full_adam.is_none();
                        Some((settings, start_state, steps, need_full_adam))
                    }
                    _ => None,
                },
                None => None,
            }
        };
        let Some((settings, start_state, steps, need_full_adam)) = extracted else {
            return Ok(());
        };

        if need_full_adam {
            let full_adam = self.model.new_full_adam_state();
            if let Some(online) = self.online.as_mut() {
                online.full_adam = Some(full_adam);
            }
        }

        // Copy-on-write: if the model Arc is shared (e.g. by a clone), make a
        // private copy before mutating weights.
        let model = Arc::make_mut(&mut self.model);
        let Some(online) = self.online.as_mut() else {
            return Ok(());
        };
        model.online_train_segment_tbptt(
            &mut self.scratch,
            &start_state,
            &steps,
            settings.scope,
            settings.optimizer,
            settings.lr,
            settings.clip,
            TBPTT_REPLAY_CHUNK,
            &mut online.adam_t,
            online.full_adam.as_mut(),
            // The bias vector and its Adam moments are only handed over when
            // the training scope includes the output bias.
            if settings.scope.bias {
                Some(online.out_bias.as_mut_slice())
            } else {
                None
            },
            if settings.scope.bias {
                online.adam_m.as_deref_mut()
            } else {
                None
            },
            if settings.scope.bias {
                online.adam_v.as_deref_mut()
            } else {
                None
            },
            &mut self.state,
        )?;
        // Refresh the cached PDF from the post-training logits so callers see
        // a distribution consistent with the updated weights.
        let bias = self.online.as_ref().map(|o| o.out_bias.as_slice());
        Self::logits_to_pdf(self.scratch.logits(), bias, &mut self.pdf_buffer);
        Ok(())
    }
1152
    /// Queue one (input token, target symbol) pair for full-parameter TBPTT.
    ///
    /// Flushes the pending segment first when the training settings changed
    /// (so one segment always trains under a single configuration), and again
    /// once the buffered segment reaches the configured BPTT window.
    fn enqueue_full_tbptt_step(
        &mut self,
        settings: FullTrainSettings,
        target_symbol: u8,
    ) -> Result<()> {
        // Flush if the new settings differ from the in-flight segment's.
        let should_flush = {
            let Some(online) = self.online.as_mut() else {
                return Ok(());
            };
            let Some(tbptt) = online.full_tbptt.as_mut() else {
                bail!("rwkv full-parameter online training requires trace-enabled tbptt runtime");
            };
            tbptt.settings.is_some_and(|prev| {
                !prev.matches(
                    settings.optimizer,
                    settings.lr,
                    settings.scope,
                    settings.bptt,
                    settings.clip,
                )
            }) && !tbptt.steps.is_empty()
        };
        if should_flush {
            self.flush_full_tbptt_segment()?;
        }

        // Append the pending (token, pre-state) recorded by the last forward
        // pass; the pre-state of the first step anchors the whole segment.
        let flush_now = {
            let Some(online) = self.online.as_mut() else {
                return Ok(());
            };
            let Some(tbptt) = online.full_tbptt.as_mut() else {
                bail!("rwkv full-parameter online training requires trace-enabled tbptt runtime");
            };
            let Some(input_token) = tbptt.pending_input_token.take() else {
                return Ok(());
            };
            let input_pre_state = tbptt
                .pending_input_pre_state
                .take()
                .ok_or_else(|| anyhow::anyhow!("rwkv full tbptt pending pre-state is missing"))?;
            if tbptt.steps.is_empty() {
                tbptt.segment_start_state = Some(input_pre_state);
            }
            tbptt.settings = Some(settings);
            tbptt.steps.push((input_token, target_symbol));
            // Flush as soon as the window is full.
            tbptt.steps.len() >= settings.bptt.max(1)
        };
        if flush_now {
            self.flush_full_tbptt_segment()?;
        }
        Ok(())
    }
1205
    /// Begin a policy stream with optional total symbol count.
    pub fn begin_online_policy_stream(&mut self, total_symbols: Option<u64>) -> Result<()> {
        // Flush any in-flight TBPTT segment before re-arming the policy cursor.
        self.finish_online_policy_stream()?;
        self.prepare_policy_stream(total_symbols)
    }
1211
    /// Flush any pending TBPTT segment while preserving the current predictive state.
    pub fn finish_online_policy_stream(&mut self) -> Result<()> {
        // Trains on whatever steps are buffered; resetting the hidden state
        // is left to the caller (see `restart_online_policy_stream`).
        self.flush_full_tbptt_segment()
    }
1216
    /// Reset hidden state and TBPTT bookkeeping for a fresh stream.
    pub fn restart_online_policy_stream(&mut self, total_symbols: Option<u64>) -> Result<()> {
        // Flush first so buffered training steps are not silently discarded.
        self.finish_online_policy_stream()?;
        self.state.reset();
        self.clear_online_training_buffers();
        self.prepare_policy_stream(total_symbols)
    }
1224
    /// Reset state and prime the first predictive distribution.
    pub fn reset_and_prime(&mut self) {
        self.state.reset();
        self.clear_online_training_buffers();
        // Feed token 0 from the fresh state so `pdf_buffer` holds a valid
        // first prediction before any real symbol is observed.
        self.refresh_current_pdf(0);
    }
1231
1232    /// Capture runtime state for later restoration.
1233    pub fn snapshot_runtime(&self) -> RuntimeSnapshot {
1234        RuntimeSnapshot {
1235            model: self.model.clone(),
1236            scratch: self.scratch.clone(),
1237            state: self.state.clone(),
1238            pdf_buffer: self.pdf_buffer.clone(),
1239            online: self.online.clone(),
1240        }
1241    }
1242
1243    /// Restore previously captured runtime state.
1244    pub fn restore_runtime(&mut self, snapshot: &RuntimeSnapshot) {
1245        self.model = snapshot.model.clone();
1246        self.scratch = snapshot.scratch.clone();
1247        self.state = snapshot.state.clone();
1248        self.pdf_buffer.clone_from(&snapshot.pdf_buffer);
1249        self.online = snapshot.online.clone();
1250    }
1251
1252    /// Absorb a sequence of byte slices as conditioning context.
1253    pub fn absorb_chain(&mut self, parts: &[&[u8]]) -> Result<()> {
1254        let total = parts
1255            .iter()
1256            .fold(0u64, |acc, part| acc.saturating_add(part.len() as u64));
1257        self.fit_chain(parts, Some(total))
1258    }
1259
1260    /// Score bytes from the current predictive state.
1261    pub fn cross_entropy_from_current(&mut self, data: &[u8]) -> Result<f64> {
1262        if data.is_empty() {
1263            return Ok(0.0);
1264        }
1265        self.begin_online_policy_stream(Some(data.len() as u64))?;
1266        let mut total_bits = 0.0f64;
1267        for &byte in data {
1268            total_bits -= self.pdf_buffer[byte as usize].log2();
1269            self.observe_symbol_from_current_pdf(byte)?;
1270        }
1271        self.finish_online_policy_stream()?;
1272        Ok(total_bits / (data.len() as f64))
1273    }
1274
    /// Fit on `fit_parts`, then reset stream state and score `data` without further adaptation.
    ///
    /// Returns the average cross-entropy of `data` in bits per byte under the
    /// weights produced by fitting. When the configuration cannot adapt at
    /// all, fitting would be a no-op, so this falls back to `cross_entropy`.
    pub fn cross_entropy_frozen_plugin_chain(
        &mut self,
        fit_parts: &[&[u8]],
        data: &[u8],
    ) -> Result<f64> {
        if data.is_empty() {
            return Ok(0.0);
        }
        if !self.can_adapt_online() {
            // Plain scoring path (defined elsewhere in this file).
            return self.cross_entropy(data);
        }
        self.finish_online_policy_stream()?;
        self.reset_and_prime();
        let fit_total = fit_parts
            .iter()
            .fold(0u64, |acc, part| acc.saturating_add(part.len() as u64));
        self.fit_chain(fit_parts, Some(fit_total))?;
        // Second reset: score from a clean hidden state with the fitted weights.
        self.reset_and_prime();

        let mut total_bits = 0.0f64;
        for &byte in data {
            // Probability floor avoids -inf contributions from zeros.
            total_bits -= self.pdf_buffer[byte as usize].max(1e-300).log2();
            // Inference only: advance the state, no weight updates.
            self.advance_inference_only(byte);
        }
        Ok(total_bits / (data.len() as f64))
    }
1302
    /// Returns `true` when the compressor is in online-adaptation mode.
    ///
    /// Online mode is set up by `new_from_method` or restored from a sidecar.
    pub fn is_online(&self) -> bool {
        self.online.is_some()
    }
1307
1308    /// Returns `true` when the current online configuration can actually adapt parameters.
1309    pub fn can_adapt_online(&self) -> bool {
1310        let Some(online) = &self.online else {
1311            return false;
1312        };
1313        match &online.policy {
1314            Some(policy) => llm_policy::policy_can_train(policy),
1315            None => !matches!(online.cfg.train_mode, OnlineTrainMode::None),
1316        }
1317    }
1318
1319    /// Number of tokens processed by the online updater.
1320    pub fn tokens_processed(&self) -> u64 {
1321        self.online.as_ref().map_or(0, |s| s.tokens_processed)
1322    }
1323
1324    /// Canonical method string for online mode, if enabled.
1325    pub fn online_method_string(&self) -> Option<&str> {
1326        self.online.as_ref().map(|s| s.canonical_method.as_str())
1327    }
1328
    /// Get the vocabulary size (should always be 256 for byte-level).
    ///
    /// All per-symbol buffers in this struct are sized from this value.
    pub fn vocab_size(&self) -> usize {
        self.model.config().vocab_size
    }
1333
1334    /// Apply optional online bias to logits and emit normalized PDF.
1335    pub fn online_apply_logits_bias(&self, logits: &[f32], pdf_out: &mut [f64]) {
1336        let bias = self.online.as_ref().map(|s| s.out_bias.as_slice());
1337        Self::logits_to_pdf(logits, bias, pdf_out);
1338    }
1339
    /// Convert logits (and optional bias) to a normalized PDF.
    ///
    /// Thin wrapper over `softmax_pdf_floor_with_bias` (defined elsewhere in
    /// this file); per its name it applies a softmax with a probability floor.
    pub fn logits_to_pdf(logits: &[f32], bias: Option<&[f32]>, pdf_out: &mut [f64]) {
        softmax_pdf_floor_with_bias(logits, bias, pdf_out);
    }
1344
1345    #[inline]
1346    /// Forward one token and emit the resulting (optionally biased) PDF.
1347    pub fn forward_to_pdf(&mut self, token: u32, pdf_out: &mut [f64]) {
1348        self.forward_with_online_record(token);
1349        let bias = self.online.as_ref().map(|o| o.out_bias.as_slice());
1350        Self::logits_to_pdf(self.scratch.logits(), bias, pdf_out);
1351    }
1352
    #[inline]
    /// Refresh the internal cached PDF buffer from `token`.
    pub fn forward_to_internal_pdf(&mut self, token: u32) {
        self.refresh_current_pdf(token);
    }
1358
1359    #[inline]
1360    /// Copy the internal cached PDF into `pdf_out` (length must match vocab size).
1361    pub fn copy_current_pdf_to(&self, pdf_out: &mut [f64]) {
1362        assert_eq!(
1363            pdf_out.len(),
1364            self.pdf_buffer.len(),
1365            "rwkv pdf output length mismatch"
1366        );
1367        pdf_out.copy_from_slice(&self.pdf_buffer);
1368    }
1369
1370    /// Snapshot current online output bias, if online mode is active.
1371    pub fn online_bias_snapshot(&self) -> Option<Vec<f32>> {
1372        self.online.as_ref().map(|o| o.out_bias.clone())
1373    }
1374
1375    #[inline]
1376    /// Borrow online output bias vector when online mode is active.
1377    pub fn online_bias_slice(&self) -> Option<&[f32]> {
1378        self.online.as_ref().map(|o| o.out_bias.as_slice())
1379    }
1380
1381    #[inline]
1382    fn refresh_current_pdf(&mut self, token: u32) {
1383        self.forward_with_online_record(token);
1384        let bias = self.online.as_ref().map(|o| o.out_bias.as_slice());
1385        Self::logits_to_pdf(self.scratch.logits(), bias, &mut self.pdf_buffer);
1386    }
1387
1388    fn fit_chain(&mut self, parts: &[&[u8]], total_symbols: Option<u64>) -> Result<()> {
1389        self.begin_online_policy_stream(total_symbols)?;
1390        for part in parts {
1391            for &byte in *part {
1392                self.observe_symbol_from_current_pdf(byte)?;
1393            }
1394        }
1395        self.finish_online_policy_stream()?;
1396        Ok(())
1397    }
1398
    #[inline]
    /// Advance the model with `symbol` without any online weight updates.
    fn advance_inference_only(&mut self, symbol: u8) {
        self.refresh_current_pdf(symbol as u32);
    }
1403
    /// Resolve the effective training action for the next symbol.
    ///
    /// Starts from the baseline `OnlineConfig` (mode/lr/stride, head+bias
    /// scope), then lets the policy stream override it: `Infer` clears the
    /// training scope, `Train` substitutes the policy's optimizer and
    /// hyper-parameters. Advances the policy cursor as a side effect.
    /// Returns `(optimizer, lr, stride, scope, bptt, clip)`.
    fn resolve_online_train_action(
        online: &mut OnlineRuntime,
    ) -> Result<(OptimizerKind, f32, u64, rwkv7::TrainScopeMask, usize, f32)> {
        let mut optimizer = match online.cfg.train_mode {
            // `None` still needs a placeholder optimizer; the empty scope is
            // what actually disables training downstream.
            OnlineTrainMode::None => OptimizerKind::Sgd,
            OnlineTrainMode::Sgd => OptimizerKind::Sgd,
            OnlineTrainMode::Adam => OptimizerKind::Adam,
        };
        let mut lr = online.cfg.lr.max(0.0);
        let mut stride = online.cfg.stride.max(1) as u64;
        let mut scope = rwkv7::TrainScopeMask::default();
        let default_train = !matches!(online.cfg.train_mode, OnlineTrainMode::None);
        // Baseline (no policy action): train the LM head and output bias only.
        scope.head = default_train;
        scope.bias = default_train;
        let mut bptt = 1usize;
        let mut clip = 0.0f32;

        if let Some(action) = online.next_policy_action()? {
            match action {
                PolicyAction::Infer => {
                    // Default mask selects no parameters: inference only.
                    scope = rwkv7::TrainScopeMask::default();
                }
                PolicyAction::Train(train) => {
                    optimizer = train.optimizer;
                    lr = train.hyper.lr.max(0.0);
                    stride = train.hyper.stride.max(1) as u64;
                    bptt = train.hyper.bptt.max(1);
                    clip = train.hyper.clip.max(0.0);
                    scope = scope_from_train_action(&train);
                }
            }
        }

        Ok((optimizer, lr, stride, scope, bptt, clip))
    }
1439
    /// Apply one online update using externally supplied predictive PDF.
    ///
    /// Unlike `observe_symbol_from_pdf`, this does NOT advance the model state.
    pub fn online_update_from_pdf(&mut self, symbol: u8, pdf: &[f64]) -> Result<()> {
        self.online_update_with_pdf(symbol, pdf)
    }
1444
    #[inline]
    /// Update online state from `pdf`, then advance model state with `symbol`.
    pub fn observe_symbol_from_pdf(&mut self, symbol: u8, pdf: &[f64]) -> Result<()> {
        // Train first — `pdf` is the prediction made before seeing `symbol` —
        // then consume the symbol to produce the next prediction.
        self.online_update_with_pdf(symbol, pdf)?;
        self.refresh_current_pdf(symbol as u32);
        Ok(())
    }
1452
    /// Core online-learning step driven by the predictive `pdf` for `symbol`.
    ///
    /// Resolves the training action for this token, then takes one of three
    /// paths: skip training (flushing any pending TBPTT segment), apply an
    /// immediate LM-head/bias update, or enqueue a full-parameter TBPTT step.
    /// No-op when online mode is disabled.
    fn online_update_with_pdf(&mut self, symbol: u8, pdf: &[f64]) -> Result<()> {
        // Short borrow of `self.online` to resolve the action and stride gate.
        let (optimizer, lr, stride_hit, scope, bptt, clip) = {
            let Some(online) = self.online.as_mut() else {
                return Ok(());
            };
            online.tokens_processed = online.tokens_processed.saturating_add(1);
            let (optimizer, lr, stride, scope, bptt, clip) =
                Self::resolve_online_train_action(online)?;
            // Stride gating: only every `stride`-th train-eligible token
            // actually triggers an update.
            let mut stride_hit = false;
            if scope.trains_any_params() {
                online.policy_train_steps = online.policy_train_steps.saturating_add(1);
                stride_hit = stride <= 1 || (online.policy_train_steps % stride) == 0;
            }
            (optimizer, lr, stride_hit, scope, bptt, clip)
        };

        // Not training this token: flush any buffered segment and drop the
        // recorded pending step so it is never replayed.
        if !scope.trains_any_params() || !stride_hit || lr == 0.0 {
            self.flush_full_tbptt_segment()?;
            if let Some(online) = self.online.as_mut()
                && let Some(tbptt) = online.full_tbptt.as_mut()
            {
                tbptt.pending_input_token = None;
                tbptt.pending_input_pre_state = None;
            }
            return Ok(());
        }

        // Lazily allocate bias Adam moments the first time Adam touches the bias.
        if matches!(optimizer, OptimizerKind::Adam)
            && let Some(online) = self.online.as_mut()
            && scope.bias
            && (online.adam_m.is_none() || online.adam_v.is_none())
        {
            online.adam_m = Some(vec![0.0; online.out_bias.len()]);
            online.adam_v = Some(vec![0.0; online.out_bias.len()]);
        }

        // Head/bias-only training takes the immediate (non-TBPTT) path.
        if !scope.trains_non_head_params() {
            // Copy inputs out before the flush below can clobber scratch state.
            let hidden = self.scratch.lm_head_input().to_vec();
            let pdf_snapshot = pdf.to_vec();
            self.flush_full_tbptt_segment()?;
            let Some(online) = self.online.as_mut() else {
                return Ok(());
            };
            if let Some(tbptt) = online.full_tbptt.as_mut() {
                tbptt.pending_input_token = None;
                tbptt.pending_input_pre_state = None;
            }
            // Copy-on-write if the model Arc is shared.
            let model = Arc::make_mut(&mut self.model);
            apply_online_lm_head_update(
                model,
                online,
                &hidden,
                symbol,
                &pdf_snapshot,
                lr,
                optimizer,
                scope.head,
                scope.bias,
                clip,
            );
            return Ok(());
        }

        // Full-parameter training: buffer this step for windowed TBPTT replay.
        let settings = FullTrainSettings {
            optimizer,
            lr,
            scope,
            bptt: Self::effective_full_bptt(scope, bptt),
            clip,
        };
        self.enqueue_full_tbptt_step(settings, symbol)
    }
1525
    /// Online update driven by the internally cached PDF.
    fn online_update_from_current_pdf(&mut self, symbol: u8) -> Result<()> {
        // The clone is required: the update path may itself rewrite
        // `self.pdf_buffer` (TBPTT segment flush), so it cannot be borrowed
        // across the call.
        let pdf_snapshot = self.pdf_buffer.clone();
        self.online_update_with_pdf(symbol, &pdf_snapshot)
    }
1530
    #[inline]
    /// Update online state using current internal PDF, then consume `symbol`.
    pub fn observe_symbol_from_current_pdf(&mut self, symbol: u8) -> Result<()> {
        // Train on the prediction made before `symbol`, then advance the
        // state so `pdf_buffer` predicts the following symbol.
        self.online_update_from_current_pdf(symbol)?;
        self.refresh_current_pdf(symbol as u32);
        Ok(())
    }
1538
    /// Export model weights and JSON sidecar metadata.
    ///
    /// Writes weights to `model_path`, full-Adam optimizer state (when
    /// present) to a companion file, and a `<model>.json` sidecar describing
    /// the online runtime so `maybe_load_sidecar` can restore it on the next
    /// load. Stale optimizer files from earlier exports are removed.
    pub fn export_online<P: AsRef<Path>>(&self, model_path: P) -> Result<()> {
        let model_path = model_path.as_ref();
        self.model.save_safetensors(model_path)?;
        let opt_sidecar = optimizer_sidecar_path(model_path);

        let sidecar = model_path.with_extension("json");
        let meta = if let Some(online) = &self.online {
            if let Some(full_adam) = online.full_adam.as_ref() {
                self.model
                    .save_full_adam_safetensors(full_adam, &opt_sidecar)?;
            } else if opt_sidecar.exists() {
                // Drop a stale optimizer file so it cannot be paired with
                // weights it no longer matches; deletion failure is benign.
                let _ = fs::remove_file(&opt_sidecar);
            }
            let train_mode = match online.cfg.train_mode {
                OnlineTrainMode::None => "none",
                OnlineTrainMode::Sgd => "sgd",
                OnlineTrainMode::Adam => "adam",
            };
            json!({
                "version": 1,
                "method": online.canonical_method,
                "policy": online.policy.as_ref().map(LlmPolicy::canonical),
                "policy_cursor": online.policy_runtime.as_ref().map(PolicyRuntime::cursor).unwrap_or(0),
                "policy_stream_total": online.policy_stream_total,
                "policy_train_steps": online.policy_train_steps,
                "training_mode": train_mode,
                "tokens_processed": online.tokens_processed,
                "adam_t": online.adam_t,
                "has_full_adam": online.full_adam.is_some(),
                "config": {
                    "hidden": online.cfg.hidden,
                    "layers": online.cfg.layers,
                    "intermediate": online.cfg.intermediate,
                    "decay_rank": online.cfg.decay_rank,
                    "a_rank": online.cfg.a_rank,
                    "v_rank": online.cfg.v_rank,
                    "g_rank": online.cfg.g_rank,
                    "seed": online.cfg.seed,
                    "lr": online.cfg.lr,
                    "stride": online.cfg.stride.max(1),
                },
                "output_bias": online.out_bias,
                "adam_m": online.adam_m,
                "adam_v": online.adam_v,
                "lm_head_adam_m": online.lm_head_adam_m,
                "lm_head_adam_v": online.lm_head_adam_v,
            })
        } else {
            // No online runtime: emit a minimal sidecar and remove any stale
            // optimizer state left over from a previous online export.
            if opt_sidecar.exists() {
                let _ = fs::remove_file(&opt_sidecar);
            }
            json!({
                "version": 1,
                "method": format!("file:{}", model_path.display()),
                "training_mode": "none",
                "tokens_processed": 0,
            })
        };

        fs::write(&sidecar, serde_json::to_vec_pretty(&meta)?)?;
        Ok(())
    }
1602
1603    fn maybe_load_sidecar(&mut self) -> Result<()> {
1604        let Some(model_path) = &self.source_model_path else {
1605            return Ok(());
1606        };
1607        let sidecar = model_path.with_extension("json");
1608        if !sidecar.exists() {
1609            return Ok(());
1610        }
1611        let raw = fs::read(&sidecar)?;
1612        let v: serde_json::Value = serde_json::from_slice(&raw)?;
1613        let parse_vec_f32 = |key: &str| -> Option<Vec<f32>> {
1614            v.get(key).and_then(|arr| arr.as_array()).map(|arr| {
1615                arr.iter()
1616                    .map(|x| x.as_f64().unwrap_or(0.0) as f32)
1617                    .collect::<Vec<f32>>()
1618            })
1619        };
1620        let output_bias = v
1621            .get("output_bias")
1622            .and_then(|arr| arr.as_array())
1623            .map(|arr| {
1624                arr.iter()
1625                    .map(|x| x.as_f64().unwrap_or(0.0) as f32)
1626                    .collect::<Vec<f32>>()
1627            });
1628        let method = v
1629            .get("method")
1630            .and_then(|m| m.as_str())
1631            .map(|s| s.to_string())
1632            .unwrap_or_else(|| format!("file:{}", model_path.display()));
1633        let has_full_adam = v
1634            .get("has_full_adam")
1635            .and_then(|x| x.as_bool())
1636            .unwrap_or(false);
1637        let policy = v
1638            .get("policy")
1639            .and_then(|p| p.as_str())
1640            .and_then(|s| llm_policy::parse_policy_segment(s, RWKV_TRAIN_SCOPES).ok());
1641        let tokens = v
1642            .get("tokens_processed")
1643            .and_then(|t| t.as_u64())
1644            .unwrap_or(0);
1645        if let Some(mut out_bias) = output_bias {
1646            out_bias.resize(self.vocab_size(), 0.0);
1647            let mut cfg = OnlineConfig::default();
1648            if let Some(cfg_v) = v.get("config").and_then(|x| x.as_object()) {
1649                if let Some(x) = cfg_v.get("hidden").and_then(|x| x.as_u64()) {
1650                    cfg.hidden = x as usize;
1651                }
1652                if let Some(x) = cfg_v.get("layers").and_then(|x| x.as_u64()) {
1653                    cfg.layers = x as usize;
1654                }
1655                if let Some(x) = cfg_v.get("intermediate").and_then(|x| x.as_u64()) {
1656                    cfg.intermediate = x as usize;
1657                }
1658                if let Some(x) = cfg_v.get("decay_rank").and_then(|x| x.as_u64()) {
1659                    cfg.decay_rank = x as usize;
1660                }
1661                if let Some(x) = cfg_v.get("a_rank").and_then(|x| x.as_u64()) {
1662                    cfg.a_rank = x as usize;
1663                }
1664                if let Some(x) = cfg_v.get("v_rank").and_then(|x| x.as_u64()) {
1665                    cfg.v_rank = x as usize;
1666                }
1667                if let Some(x) = cfg_v.get("g_rank").and_then(|x| x.as_u64()) {
1668                    cfg.g_rank = x as usize;
1669                }
1670                if let Some(x) = cfg_v.get("seed").and_then(|x| x.as_u64()) {
1671                    cfg.seed = x;
1672                }
1673                if let Some(x) = cfg_v.get("lr").and_then(|x| x.as_f64()) {
1674                    cfg.lr = x as f32;
1675                }
1676                if let Some(x) = cfg_v.get("stride").and_then(|x| x.as_u64()) {
1677                    cfg.stride = (x as usize).max(1);
1678                }
1679            }
1680            cfg.train_mode = v
1681                .get("training_mode")
1682                .and_then(|x| x.as_str())
1683                .and_then(|s| parse_train_mode_token(s).ok())
1684                .unwrap_or(OnlineTrainMode::None);
1685            let needs_full_trace = policy
1686                .as_ref()
1687                .map(policy_needs_full_trace)
1688                .unwrap_or(false);
1689            self.online = Some(OnlineRuntime {
1690                cfg,
1691                canonical_method: method,
1692                policy,
1693                policy_runtime: None,
1694                needs_full_trace,
1695                policy_stream_total: v.get("policy_stream_total").and_then(|x| x.as_u64()),
1696                policy_train_steps: v
1697                    .get("policy_train_steps")
1698                    .and_then(|x| x.as_u64())
1699                    .unwrap_or(0),
1700                tokens_processed: tokens,
1701                out_bias,
1702                adam_m: parse_vec_f32("adam_m"),
1703                adam_v: parse_vec_f32("adam_v"),
1704                full_adam: None,
1705                lm_head_adam_m: parse_vec_f32("lm_head_adam_m"),
1706                lm_head_adam_v: parse_vec_f32("lm_head_adam_v"),
1707                adam_t: v.get("adam_t").and_then(|x| x.as_u64()).unwrap_or(0) as usize,
1708                full_tbptt: needs_full_trace.then(|| FullTbpttRuntime {
1709                    pending_input_token: None,
1710                    pending_input_pre_state: None,
1711                    segment_start_state: None,
1712                    steps: Vec::new(),
1713                    settings: None,
1714                }),
1715            });
1716            let opt_sidecar = optimizer_sidecar_path(model_path);
1717            if opt_sidecar.exists() {
1718                if let Some(online) = self.online.as_mut() {
1719                    online.full_adam = Some(self.model.load_full_adam_safetensors(&opt_sidecar)?);
1720                }
1721            } else if has_full_adam {
1722                bail!(
1723                    "missing optimizer sidecar '{}' required for exact online resume",
1724                    opt_sidecar.display()
1725                );
1726            }
1727            if let Some(cursor) = v.get("policy_cursor").and_then(|x| x.as_u64())
1728                && let Some(online) = self.online.as_mut()
1729                && online.policy.is_some()
1730            {
1731                let train_steps = online.policy_train_steps;
1732                online.prepare_policy_stream(online.policy_stream_total)?;
1733                online.policy_train_steps = train_steps;
1734                if let Some(rt) = online.policy_runtime.as_mut() {
1735                    rt.set_cursor(cursor);
1736                }
1737            }
1738            self.scratch
1739                .set_capture_train_trace(self.online.as_ref().is_some_and(|o| o.needs_full_trace));
1740        }
1741        Ok(())
1742    }
1743
1744    /// Compress data using the specified entropy coder.
1745    ///
1746    /// # Arguments
1747    /// * `data` - Raw bytes to compress
1748    /// * `coder` - Entropy coder to use (AC or rANS)
1749    ///
1750    /// # Returns
1751    /// Compressed data including header with checksum.
1752    pub fn compress(&mut self, data: &[u8], coder: CoderType) -> Result<Vec<u8>> {
1753        let mut output = Vec::new();
1754        self.compress_into(data, coder, &mut output)?;
1755        Ok(output)
1756    }
1757
1758    /// Compress into an arbitrary writer.
1759    pub fn compress_into<W: Write>(
1760        &mut self,
1761        data: &[u8],
1762        coder: CoderType,
1763        w: &mut W,
1764    ) -> Result<()> {
1765        self.restart_online_policy_stream(Some(data.len() as u64))?;
1766
1767        let checksum = crc32(data);
1768        let header = Header::new(coder, data.len() as u64, checksum);
1769        header.write(w)?;
1770
1771        match coder {
1772            CoderType::AC => self.compress_ac(data, w)?,
1773            CoderType::RANS => self.compress_rans(data, w)?,
1774        }
1775
1776        Ok(())
1777    }
1778
1779    /// Compress a chain of byte slices into an arbitrary writer.
1780    pub fn compress_chain_into<W: Write>(
1781        &mut self,
1782        parts: &[&[u8]],
1783        coder: CoderType,
1784        w: &mut W,
1785    ) -> Result<()> {
1786        let mut total_len: u64 = 0;
1787        let mut hasher = crc32fast::Hasher::new();
1788        for p in parts {
1789            total_len = total_len.saturating_add(p.len() as u64);
1790            hasher.update(p);
1791        }
1792        let checksum = hasher.finalize();
1793        self.restart_online_policy_stream(Some(total_len))?;
1794
1795        let header = Header::new(coder, total_len, checksum);
1796        header.write(w)?;
1797
1798        let it = parts.iter().flat_map(|p| p.iter().copied());
1799        match coder {
1800            CoderType::AC => self.compress_ac_iter(it, w)?,
1801            CoderType::RANS => self.compress_rans_iter(it, w)?,
1802        }
1803
1804        Ok(())
1805    }
1806
1807    /// Return compressed byte size without materializing output bytes.
1808    pub fn compress_size(&mut self, data: &[u8], coder: CoderType) -> Result<u64> {
1809        let mut w = CountingWriter::new();
1810        self.compress_into(data, coder, &mut w)?;
1811        Ok(w.bytes_written())
1812    }
1813
1814    /// Return compressed byte size for chained inputs.
1815    pub fn compress_size_chain(&mut self, parts: &[&[u8]], coder: CoderType) -> Result<u64> {
1816        let mut w = CountingWriter::new();
1817        self.compress_chain_into(parts, coder, &mut w)?;
1818        Ok(w.bytes_written())
1819    }
1820
1821    /// Compress using arithmetic coding.
1822    fn compress_ac<W: Write>(&mut self, data: &[u8], output: &mut W) -> Result<()> {
1823        self.compress_ac_iter(data.iter().copied(), output)
1824    }
1825
1826    fn compress_ac_iter<I, W: Write>(&mut self, data: I, output: &mut W) -> Result<()>
1827    where
1828        I: IntoIterator<Item = u8>,
1829    {
1830        let mut encoder = ArithmeticEncoder::new(output);
1831
1832        // Prime the model with a null byte to establish initial state
1833        self.refresh_current_pdf(0);
1834
1835        for byte in data {
1836            quantize_pdf_to_cdf_with_buffer(
1837                &self.pdf_buffer,
1838                &mut self.cdf_buffer_ac,
1839                &mut self.ac_freq_buffer,
1840            );
1841            let sym = byte as usize;
1842            let c_lo = self.cdf_buffer_ac[sym] as u64;
1843            let c_hi = self.cdf_buffer_ac[sym + 1] as u64;
1844            encoder.encode_counts(c_lo, c_hi, CDF_TOTAL as u64)?;
1845            self.observe_symbol_from_current_pdf(byte)?;
1846        }
1847
1848        let _ = encoder.finish()?;
1849        self.finish_online_policy_stream()?;
1850        Ok(())
1851    }
1852
1853    /// Compress using rANS coding with block-based encoding.
1854    fn compress_rans<W: Write>(&mut self, data: &[u8], output: &mut W) -> Result<()> {
1855        self.compress_rans_iter(data.iter().copied(), output)
1856    }
1857
1858    fn compress_rans_iter<I, W: Write>(&mut self, data: I, output: &mut W) -> Result<()>
1859    where
1860        I: IntoIterator<Item = u8>,
1861    {
1862        // Use blocked encoder (128KB blocks) for streaming large files
1863        let mut encoder = BlockedRansEncoder::new();
1864
1865        // Prime the model with a null byte
1866        self.refresh_current_pdf(0);
1867
1868        for byte in data {
1869            quantize_pdf_to_rans_cdf_with_buffer(
1870                &self.pdf_buffer,
1871                &mut self.cdf_buffer_rans,
1872                &mut self.rans_freq_buffer,
1873            );
1874            let sym = byte as usize;
1875            let cdf = Cdf::new(
1876                self.cdf_buffer_rans[sym],
1877                self.cdf_buffer_rans[sym + 1],
1878                ANS_TOTAL,
1879            );
1880            encoder.encode(cdf);
1881            self.observe_symbol_from_current_pdf(byte)?;
1882        }
1883
1884        // Finish encoding and write blocks
1885        let blocks = encoder.finish();
1886
1887        // Write block count
1888        output.write_all(&(blocks.len() as u32).to_le_bytes())?;
1889
1890        // Write each block with length prefix
1891        for block in &blocks {
1892            output.write_all(&(block.len() as u32).to_le_bytes())?;
1893            output.write_all(block)?;
1894        }
1895
1896        self.finish_online_policy_stream()?;
1897        Ok(())
1898    }
1899
1900    /// Decompress data.
1901    ///
1902    /// # Arguments
1903    /// * `data` - Compressed data (must include header)
1904    ///
1905    /// # Returns
1906    /// Original decompressed data. Returns error if checksum doesn't match.
1907    pub fn decompress(&mut self, data: &[u8]) -> Result<Vec<u8>> {
1908        let mut cursor = Cursor::new(data);
1909        let header = Header::read(&mut cursor)?;
1910
1911        self.restart_online_policy_stream(Some(header.original_len))?;
1912
1913        let compressed = &data[Header::SIZE..];
1914        let result = match header.coder_type() {
1915            CoderType::AC => self.decompress_ac(compressed, header.original_len as usize)?,
1916            CoderType::RANS => self.decompress_rans(compressed, header.original_len as usize)?,
1917        };
1918
1919        // Verify checksum for data integrity
1920        let actual_crc = crc32(&result);
1921        if actual_crc != header.crc32 {
1922            bail!(
1923                "CRC32 mismatch: expected 0x{:08X}, got 0x{:08X}",
1924                header.crc32,
1925                actual_crc
1926            );
1927        }
1928
1929        Ok(result)
1930    }
1931
1932    /// Decompress using arithmetic coding.
1933    fn decompress_ac(&mut self, compressed: &[u8], original_len: usize) -> Result<Vec<u8>> {
1934        let mut decoder = ArithmeticDecoder::new(compressed)?;
1935
1936        let mut result = Vec::with_capacity(original_len);
1937
1938        // Prime with null byte (must match compression)
1939        self.refresh_current_pdf(0);
1940
1941        for _ in 0..original_len {
1942            quantize_pdf_to_cdf_with_buffer(
1943                &self.pdf_buffer,
1944                &mut self.cdf_buffer_ac,
1945                &mut self.ac_freq_buffer,
1946            );
1947            let sym = decoder.decode_symbol_counts(&self.cdf_buffer_ac, CDF_TOTAL)?;
1948            result.push(sym as u8);
1949            self.observe_symbol_from_current_pdf(sym as u8)?;
1950        }
1951
1952        self.finish_online_policy_stream()?;
1953        Ok(result)
1954    }
1955
1956    /// Decompress using rANS coding.
1957    fn decompress_rans(&mut self, compressed: &[u8], original_len: usize) -> Result<Vec<u8>> {
1958        // Read block count
1959        if compressed.len() < 4 {
1960            bail!("rANS data too short");
1961        }
1962        let block_count =
1963            u32::from_le_bytes([compressed[0], compressed[1], compressed[2], compressed[3]])
1964                as usize;
1965
1966        // Read blocks
1967        let mut blocks = Vec::with_capacity(block_count);
1968        let mut pos = 4;
1969
1970        for _ in 0..block_count {
1971            if pos + 4 > compressed.len() {
1972                bail!("Truncated block header");
1973            }
1974            let block_len = u32::from_le_bytes([
1975                compressed[pos],
1976                compressed[pos + 1],
1977                compressed[pos + 2],
1978                compressed[pos + 3],
1979            ]) as usize;
1980            pos += 4;
1981
1982            if pos + block_len > compressed.len() {
1983                bail!("Truncated block data");
1984            }
1985            blocks.push(&compressed[pos..pos + block_len]);
1986            pos += block_len;
1987        }
1988
1989        // Decode using blocked decoder
1990        let mut decoder = BlockedRansDecoder::new(blocks, original_len)?;
1991        let mut result = Vec::with_capacity(original_len);
1992
1993        // Prime with null byte
1994        self.refresh_current_pdf(0);
1995
1996        for _ in 0..original_len {
1997            quantize_pdf_to_rans_cdf_with_buffer(
1998                &self.pdf_buffer,
1999                &mut self.cdf_buffer_rans,
2000                &mut self.rans_freq_buffer,
2001            );
2002            let sym = decoder.decode(&self.cdf_buffer_rans)?;
2003            result.push(sym as u8);
2004            self.observe_symbol_from_current_pdf(sym as u8)?;
2005        }
2006
2007        self.finish_online_policy_stream()?;
2008        Ok(result)
2009    }
2010
    /// Calculate cross-entropy (bits per byte) for data without compression.
    ///
    /// This measures how well the model predicts the data, giving a theoretical
    /// lower bound on achievable compression. Useful for evaluating model quality.
    ///
    /// # Arguments
    /// * `data` - Data to analyze
    ///
    /// # Returns
    /// Average bits per byte (lower is better, 8.0 means no compression possible).
    pub fn cross_entropy(&mut self, data: &[u8]) -> Result<f64> {
        // Flush any in-flight policy stream first, then reset and re-prime so
        // the measurement starts from a clean, deterministic model state.
        self.finish_online_policy_stream()?;
        self.reset_and_prime();
        // Delegate the per-byte scoring to the shared "from current state" path.
        self.cross_entropy_from_current(data)
    }
2026
2027    /// Cross entropy conditioned on chained prefix slices.
2028    pub fn cross_entropy_conditional_chain(
2029        &mut self,
2030        prefix_parts: &[&[u8]],
2031        data: &[u8],
2032    ) -> Result<f64> {
2033        if data.is_empty() {
2034            return Ok(0.0);
2035        }
2036        let prefix_len = prefix_parts
2037            .iter()
2038            .fold(0usize, |acc, p| acc.saturating_add(p.len()));
2039        self.finish_online_policy_stream()?;
2040        self.reset_and_prime();
2041        self.fit_chain(prefix_parts, Some((prefix_len + data.len()) as u64))?;
2042
2043        let mut total_bits = 0.0f64;
2044        for &byte in data {
2045            total_bits -= self.pdf_buffer[byte as usize].log2();
2046            self.observe_symbol_from_current_pdf(byte)?;
2047        }
2048
2049        self.finish_online_policy_stream()?;
2050        Ok(total_bits / (data.len() as f64))
2051    }
2052
2053    /// Cross entropy conditioned on a single prefix slice.
2054    pub fn cross_entropy_conditional(&mut self, prefix: &[u8], data: &[u8]) -> Result<f64> {
2055        if data.is_empty() {
2056            return Ok(0.0);
2057        }
2058
2059        self.finish_online_policy_stream()?;
2060        self.reset_and_prime();
2061        self.begin_online_policy_stream(Some((prefix.len() + data.len()) as u64))?;
2062
2063        // Condition on prefix (update state, no scoring)
2064        for &byte in prefix {
2065            self.observe_symbol_from_current_pdf(byte)?;
2066        }
2067
2068        let mut total_bits = 0.0f64;
2069        for &byte in data {
2070            total_bits -= self.pdf_buffer[byte as usize].log2();
2071            self.observe_symbol_from_current_pdf(byte)?;
2072        }
2073
2074        self.finish_online_policy_stream()?;
2075        Ok(total_bits / (data.len() as f64))
2076    }
2077
2078    /// Symmetric aligned joint cross entropy using the better ordering.
2079    pub fn joint_cross_entropy_aligned_min(&mut self, x: &[u8], y: &[u8]) -> Result<f64> {
2080        let n = x.len().min(y.len());
2081        if n == 0 {
2082            return Ok(0.0);
2083        }
2084
2085        let h_xy = self.joint_cross_entropy_aligned_order(x, y, false)?;
2086        let h_yx = self.joint_cross_entropy_aligned_order(x, y, true)?;
2087        Ok(h_xy.min(h_yx))
2088    }
2089
2090    fn joint_cross_entropy_aligned_order(&mut self, x: &[u8], y: &[u8], swap: bool) -> Result<f64> {
2091        let n = x.len().min(y.len());
2092        if n == 0 {
2093            return Ok(0.0);
2094        }
2095
2096        self.restart_online_policy_stream(Some((2 * n) as u64))?;
2097
2098        self.refresh_current_pdf(0);
2099
2100        let mut total_bits = 0.0f64;
2101        for i in 0..n {
2102            let a = if swap { y[i] } else { x[i] };
2103            let b = if swap { x[i] } else { y[i] };
2104
2105            let pa = self.pdf_buffer[a as usize];
2106            total_bits -= pa.log2();
2107            self.observe_symbol_from_current_pdf(a)?;
2108
2109            let pb = self.pdf_buffer[b as usize];
2110            total_bits -= pb.log2();
2111            self.observe_symbol_from_current_pdf(b)?;
2112        }
2113
2114        self.finish_online_policy_stream()?;
2115        Ok(total_bits / (n as f64))
2116    }
2117}
2118
2119// =============================================================================
2120// Compression Statistics
2121// =============================================================================
2122
/// Statistics from a compression operation.
#[derive(Debug, Clone)]
pub struct CompressionStats {
    /// Original size in bytes.
    pub original_size: usize,
    /// Compressed size in bytes (including header).
    pub compressed_size: usize,
    /// Compression ratio (original/compressed). Higher is better.
    pub ratio: f64,
    /// Bits per byte. Lower is better (theoretical minimum: ~0, maximum: 8).
    pub bits_per_byte: f64,
    /// Time taken in seconds.
    pub time_seconds: f64,
    /// Throughput in bytes per second.
    pub throughput: f64,
}

impl std::fmt::Display for CompressionStats {
    /// Render a one-line, human-readable summary of the run.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            original_size,
            compressed_size,
            ratio,
            bits_per_byte,
            time_seconds,
            throughput,
        } = self;
        write!(
            f,
            "{} bytes -> {} bytes | ratio={:.3} | bits/byte={:.3} | time={:.2}s | {:.0} B/s",
            original_size, compressed_size, ratio, bits_per_byte, time_seconds, throughput,
        )
    }
}
2154
2155/// Compress data and return both the compressed output and statistics.
2156///
2157/// This is a convenience function that wraps `Compressor::compress` with timing.
2158pub fn compress_with_stats(
2159    compressor: &mut Compressor,
2160    data: &[u8],
2161    coder: CoderType,
2162) -> Result<(Vec<u8>, CompressionStats)> {
2163    let start = std::time::Instant::now();
2164    let compressed = compressor.compress(data, coder)?;
2165    let elapsed = start.elapsed().as_secs_f64();
2166
2167    let stats = CompressionStats {
2168        original_size: data.len(),
2169        compressed_size: compressed.len(),
2170        ratio: data.len() as f64 / compressed.len() as f64,
2171        bits_per_byte: (compressed.len() as f64 * 8.0) / data.len() as f64,
2172        time_seconds: elapsed,
2173        throughput: data.len() as f64 / elapsed,
2174    };
2175
2176    Ok((compressed, stats))
2177}
2178
2179// =============================================================================
2180// Tests
2181// =============================================================================
2182
2183#[cfg(test)]
2184mod tests {
2185    use super::*;
2186    use std::time::{SystemTime, UNIX_EPOCH};
2187
2188    fn temp_path(name: &str, ext: &str) -> PathBuf {
2189        let ts = SystemTime::now()
2190            .duration_since(UNIX_EPOCH)
2191            .unwrap()
2192            .as_nanos();
2193        std::env::temp_dir().join(format!("infotheory_rwkvzip_{name}_{ts}.{ext}"))
2194    }
2195
2196    #[test]
2197    fn test_header_roundtrip() {
2198        let header = Header::new(CoderType::AC, 12345, 0xDEADBEEF);
2199
2200        let mut buf = Vec::new();
2201        header.write(&mut buf).unwrap();
2202
2203        assert_eq!(buf.len(), Header::SIZE);
2204
2205        let mut cursor = Cursor::new(&buf);
2206        let read_header = Header::read(&mut cursor).unwrap();
2207
2208        assert_eq!(read_header.magic, MAGIC);
2209        assert_eq!(read_header.version, VERSION);
2210        assert_eq!(read_header.coder, 0);
2211        assert_eq!(read_header.original_len, 12345);
2212        assert_eq!(read_header.crc32, 0xDEADBEEF);
2213    }
2214
2215    #[test]
2216    fn test_header_rans() {
2217        let header = Header::new(CoderType::RANS, 67890, 0xCAFEBABE);
2218        assert_eq!(header.coder, 1);
2219        assert_eq!(header.coder_type(), CoderType::RANS);
2220    }
2221
2222    #[test]
2223    fn test_coder_type_display() {
2224        assert_eq!(format!("{}", CoderType::AC), "AC");
2225        assert_eq!(format!("{}", CoderType::RANS), "rANS");
2226    }
2227
2228    #[test]
2229    fn test_crc32() {
2230        let data = b"Hello, World!";
2231        let c = crc32(data);
2232        assert_ne!(c, 0);
2233        // CRC32 should be deterministic
2234        assert_eq!(c, crc32(data));
2235    }
2236
2237    #[test]
2238    fn test_crc32_different_data() {
2239        let c1 = crc32(b"Hello");
2240        let c2 = crc32(b"World");
2241        assert_ne!(c1, c2);
2242    }
2243
2244    #[test]
2245    fn test_crc32_known_vector() {
2246        // Standard CRC-32 (ISO-HDLC) test vector.
2247        assert_eq!(crc32(b"123456789"), 0xCBF4_3926);
2248    }
2249
2250    #[test]
2251    fn test_header_rejects_invalid_magic() {
2252        let mut buf = Vec::new();
2253        let header = Header::new(CoderType::AC, 1, 2);
2254        header.write(&mut buf).unwrap();
2255        // Corrupt magic.
2256        buf[0] ^= 0xFF;
2257
2258        let mut cursor = Cursor::new(&buf);
2259        let err = Header::read(&mut cursor).unwrap_err();
2260        let msg = format!("{err:#}");
2261        assert!(msg.contains("Invalid magic number"));
2262    }
2263
    #[test]
    fn test_parse_method_spec_file_and_cfg() {
        let p = temp_path("dummy", "bin");
        std::fs::write(&p, b"x").unwrap();

        // "file:<path>" must resolve to a File method pointing at that path.
        match parse_method_spec(&format!("file:{}", p.display())).unwrap() {
            MethodSpec::File { path: got, .. } => assert_eq!(got, p),
            _ => panic!("expected file method"),
        }

        // Keyed "cfg:" spec with an attached policy segment.
        match parse_method_spec(
            "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=1,train=none,lr=0.01,stride=2;policy:schedule=0..100:infer",
        )
        .unwrap()
        {
            MethodSpec::Online { cfg, .. } => {
                assert_eq!(cfg.hidden, 64);
                assert_eq!(cfg.layers, 1);
                assert_eq!(cfg.seed, 1);
                assert_eq!(cfg.stride, 2);
            }
            _ => panic!("expected cfg method"),
        }

        // Positional (comma-separated) cfg form maps to the same fields.
        match parse_method_spec("64,64,1,0,7,0.01,2;policy:schedule=0..100:infer").unwrap() {
            MethodSpec::Online { cfg, .. } => {
                assert_eq!(cfg.hidden, 64);
                assert_eq!(cfg.intermediate, 64);
                assert_eq!(cfg.layers, 1);
                assert_eq!(cfg.seed, 7);
                assert_eq!(cfg.stride, 2);
            }
            _ => panic!("expected positional cfg method"),
        }

        // Backward-compatible plain existing path.
        match parse_method_spec(&p.display().to_string()).unwrap() {
            MethodSpec::File { path: got, .. } => assert_eq!(got, p),
            _ => panic!("expected file method"),
        }

        std::fs::remove_file(&p).ok();
    }

    #[test]
    fn test_parse_method_spec_rejects_unknown_cfg_key() {
        // Unknown keys ("wat") must be rejected, not silently ignored.
        let err =
            parse_method_spec("cfg:hidden=64,wat=1;policy:schedule=0..100:infer").unwrap_err();
        assert!(format!("{err:#}").contains("unknown rwkv cfg key"));
    }

    #[test]
    fn test_parse_method_spec_accepts_cfg_without_policy() {
        // The ";policy:..." segment is optional for cfg specs.
        let spec = parse_method_spec("cfg:hidden=64,layers=1,intermediate=64").unwrap();
        match spec {
            MethodSpec::Online { cfg, policy } => {
                assert_eq!(cfg.hidden, 64);
                assert_eq!(cfg.layers, 1);
                assert_eq!(cfg.intermediate, 64);
                assert!(policy.is_none());
            }
            _ => panic!("expected cfg method"),
        }
    }

    #[test]
    fn test_canonical_method_omits_policy_when_absent() {
        // The canonical string fills in defaults but adds no policy segment.
        let c = Compressor::new_from_method("cfg:hidden=64,layers=1,intermediate=64").unwrap();
        assert_eq!(
            c.online_method_string(),
            Some(
                "cfg:hidden=64,layers=1,intermediate=64,decay_rank=32,a_rank=32,v_rank=32,g_rank=64,seed=0,train=none,lr=0.001,stride=1"
            )
        );
    }
2339
    #[test]
    fn test_online_export_reload_roundtrip() {
        let method = "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=7,train=sgd,lr=0.01,stride=1;policy:schedule=0..100:train(scope=head+bias,opt=sgd,lr=0.01,stride=1,bptt=1,clip=0,momentum=0.9)";
        let data = b"rwkv online export/load deterministic sample";

        // Train a little, export, then check the exported model reproduces
        // the exact same compressed bytes as the in-memory compressor.
        let mut c1 = Compressor::new_from_method(method).unwrap();
        let _ = c1.compress(data, CoderType::AC).unwrap();

        let model_path = temp_path("export", "safetensors");
        c1.export_online(&model_path).unwrap();
        let out1_after_export = c1.compress(data, CoderType::AC).unwrap();

        let mut c2 = Compressor::new(&model_path).unwrap();
        let out2 = c2.compress(data, CoderType::AC).unwrap();

        assert_eq!(out1_after_export, out2);
        // Export also writes a JSON sidecar next to the weights.
        assert!(model_path.with_extension("json").exists());

        std::fs::remove_file(&model_path).ok();
        std::fs::remove_file(model_path.with_extension("json")).ok();
    }

    #[test]
    fn test_runtime_snapshot_restores_online_state() {
        let method = "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=9,train=sgd,lr=0.01,stride=1;policy:schedule=0..100:train(scope=head+bias,opt=sgd,lr=0.01,stride=1,bptt=1,clip=0,momentum=0.9)";
        let mut c = Compressor::new_from_method(method).unwrap();
        c.reset_and_prime();
        c.absorb_chain(&[b"prior context".as_slice()]).unwrap();
        let snap = c.snapshot_runtime();

        c.absorb_chain(&[b"snippet-a".as_slice()]).unwrap();
        let score_a = c.cross_entropy_from_current(b"query").unwrap();

        // Restoring the snapshot and replaying the same snippet must yield
        // bit-identical scores (deterministic restore).
        c.restore_runtime(&snap);
        c.absorb_chain(&[b"snippet-b".as_slice()]).unwrap();
        let score_b = c.cross_entropy_from_current(b"query").unwrap();

        c.restore_runtime(&snap);
        c.absorb_chain(&[b"snippet-b".as_slice()]).unwrap();
        let score_b_again = c.cross_entropy_from_current(b"query").unwrap();

        assert!((score_b - score_b_again).abs() < 1e-12);
        let _ = score_a;
    }

    #[test]
    fn test_runtime_snapshot_restores_non_head_training_state() {
        let method = "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=15,train=adam,lr=0.0008,stride=1;policy:schedule=0..100:train(scope=all,opt=adam,lr=0.0008,stride=1,bptt=1,clip=0,momentum=0.9)";
        let mut c = Compressor::new_from_method(method).unwrap();
        c.reset_and_prime();
        c.absorb_chain(&[b"prior context".as_slice()]).unwrap();
        let snap = c.snapshot_runtime();

        // Mutate optimizer/model state after the snapshot was taken...
        let _ = c
            .cross_entropy_from_current(b"mutate model before restore")
            .unwrap();

        // ...then verify two restore+score cycles agree bit-for-bit.
        c.restore_runtime(&snap);
        let score_a = c
            .cross_entropy_from_current(b"query after restore")
            .unwrap();

        c.restore_runtime(&snap);
        let score_b = c
            .cross_entropy_from_current(b"query after restore")
            .unwrap();

        assert!((score_a - score_b).abs() < 1e-12);
    }
2409
2410    #[test]
2411    fn test_online_training_updates_lm_head_weights() {
2412        let method = "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=5,train=sgd,lr=0.01,stride=1;policy:schedule=0..100:train(scope=head+bias,opt=sgd,lr=0.01,stride=1,bptt=1,clip=0,momentum=0.9)";
2413        let mut c = Compressor::new_from_method(method).unwrap();
2414        c.reset_and_prime();
2415        let before = c.model.lm_head_weights()[0..64].to_vec();
2416        let _ = c
2417            .cross_entropy_from_current(b"online rwkv weight update")
2418            .unwrap();
2419        let after = &c.model.lm_head_weights()[0..64];
2420        let mut changed = false;
2421        for i in 0..before.len() {
2422            if before[i].to_bits() != after[i].to_bits() {
2423                changed = true;
2424                break;
2425            }
2426        }
2427        assert!(
2428            changed,
2429            "expected LM-head weights to change under online training"
2430        );
2431    }
2432
2433    #[test]
2434    fn test_cross_entropy_from_current_keeps_unique_model_arc() {
2435        let method = "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=21,train=adam,lr=0.0008,stride=1;policy:schedule=0..100:train(scope=all,opt=adam,lr=0.0008,stride=1,bptt=1,clip=0,momentum=0.9)";
2436        let mut c = Compressor::new_from_method(method).unwrap();
2437        c.reset_and_prime();
2438
2439        assert_eq!(Arc::strong_count(&c.model), 1);
2440        let before = Arc::as_ptr(&c.model);
2441        let _ = c.cross_entropy_from_current(b"arc uniqueness").unwrap();
2442        let after = Arc::as_ptr(&c.model);
2443
2444        assert_eq!(Arc::strong_count(&c.model), 1);
2445        assert_eq!(before, after);
2446    }
2447
    #[test]
    fn test_online_training_non_head_scope_updates_model_params() {
        let method = "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=13,train=sgd,lr=0.005,stride=1;policy:schedule=0..100:train(scope=attn,opt=sgd,lr=0.005,stride=1,bptt=1,clip=0,momentum=0.9)";
        let mut c = Compressor::new_from_method(method).unwrap();

        let head_before = c.model.lm_head_weights()[0..64].to_vec();
        let before_path = temp_path("rwkv_non_head_before", "safetensors");
        let after_path = temp_path("rwkv_non_head_after", "safetensors");
        c.model.save_safetensors(&before_path).unwrap();

        c.reset_and_prime();
        let _ = c
            .cross_entropy_from_current(b"rwkv non head online update")
            .unwrap();
        c.model.save_safetensors(&after_path).unwrap();

        // scope=attn must leave the LM head untouched (bit-exact)...
        let head_after = &c.model.lm_head_weights()[0..64];
        for idx in 0..head_before.len() {
            assert_eq!(
                head_before[idx].to_bits(),
                head_after[idx].to_bits(),
                "lm-head changed under scope=attn at index {idx}"
            );
        }

        // ...while the serialized non-head parameters must differ.
        let before_bytes = std::fs::read(&before_path).unwrap();
        let after_bytes = std::fs::read(&after_path).unwrap();
        assert_ne!(
            before_bytes, after_bytes,
            "expected non-head params to change"
        );

        std::fs::remove_file(&before_path).ok();
        std::fs::remove_file(&after_path).ok();
    }

    #[test]
    fn test_online_training_scope_all_bptt_gt_one_supported() {
        let method = "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=23,train=adam,lr=0.001,stride=1;policy:schedule=0..100:train(scope=all,opt=adam,lr=0.001,stride=1,bptt=2,clip=0,momentum=0.9)";
        let mut c = Compressor::new_from_method(method).unwrap();
        let before_path = temp_path("rwkv_tbptt_before", "safetensors");
        let after_path = temp_path("rwkv_tbptt_after", "safetensors");
        c.model.save_safetensors(&before_path).unwrap();
        c.reset_and_prime();
        // bptt=2 exercises the truncated-BPTT path; score must stay finite.
        let score = c.cross_entropy_from_current(b"abcdef").unwrap();
        assert!(score.is_finite());
        c.model.save_safetensors(&after_path).unwrap();
        let before_bytes = std::fs::read(&before_path).unwrap();
        let after_bytes = std::fs::read(&after_path).unwrap();
        assert_ne!(
            before_bytes, after_bytes,
            "expected tbptt training to update params"
        );
        std::fs::remove_file(&before_path).ok();
        std::fs::remove_file(&after_path).ok();
    }
2504
2505    #[test]
2506    fn test_online_training_scope_all_bptt_one_uses_fast_default_window() {
2507        let method_bptt1 = "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=27,train=adam,lr=0.0008,stride=1;policy:schedule=0..100:train(scope=all,opt=adam,lr=0.0008,stride=1,bptt=1,clip=0,momentum=0.9)";
2508        let method_bptt8 = "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=27,train=adam,lr=0.0008,stride=1;policy:schedule=0..100:train(scope=all,opt=adam,lr=0.0008,stride=1,bptt=8,clip=0,momentum=0.9)";
2509        let data = b"abcdefghij";
2510
2511        let mut c1 = Compressor::new_from_method(method_bptt1).unwrap();
2512        let mut c8 = Compressor::new_from_method(method_bptt8).unwrap();
2513
2514        let score1 = c1.cross_entropy(data).unwrap();
2515        let score8 = c8.cross_entropy(data).unwrap();
2516        assert!((score1 - score8).abs() < 1e-12);
2517
2518        let bptt1_path = temp_path("rwkv_bptt1_fast_default", "safetensors");
2519        let bptt8_path = temp_path("rwkv_bptt8_fast_default", "safetensors");
2520        c1.model.save_safetensors(&bptt1_path).unwrap();
2521        c8.model.save_safetensors(&bptt8_path).unwrap();
2522        let bptt1_bytes = std::fs::read(&bptt1_path).unwrap();
2523        let bptt8_bytes = std::fs::read(&bptt8_path).unwrap();
2524        assert_eq!(bptt1_bytes, bptt8_bytes);
2525        std::fs::remove_file(&bptt1_path).ok();
2526        std::fs::remove_file(&bptt8_path).ok();
2527    }
2528
2529    #[test]
2530    fn test_online_training_full_tbptt_updates_first_symbol_after_priming() {
2531        let method = "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=33,train=adam,lr=0.0008,stride=1;policy:schedule=0..100:train(scope=all,opt=adam,lr=0.0008,stride=1,bptt=8,clip=0,momentum=0.9)";
2532        let mut c = Compressor::new_from_method(method).unwrap();
2533        let before_path = temp_path("rwkv_first_symbol_before", "safetensors");
2534        let after_path = temp_path("rwkv_first_symbol_after", "safetensors");
2535        c.model.save_safetensors(&before_path).unwrap();
2536
2537        c.reset_and_prime();
2538        let score = c.cross_entropy_from_current(b"a").unwrap();
2539        assert!(score.is_finite());
2540        c.model.save_safetensors(&after_path).unwrap();
2541
2542        let before_bytes = std::fs::read(&before_path).unwrap();
2543        let after_bytes = std::fs::read(&after_path).unwrap();
2544        assert_ne!(
2545            before_bytes, after_bytes,
2546            "expected the first symbol after priming to update params"
2547        );
2548        std::fs::remove_file(&before_path).ok();
2549        std::fs::remove_file(&after_path).ok();
2550    }
2551
2552    #[test]
2553    fn test_online_export_reload_roundtrip_preserves_full_adam_resume() {
2554        let method = "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=31,train=adam,lr=0.0008,stride=1;policy:schedule=0..100:train(scope=all,opt=adam,lr=0.0008,stride=1,bptt=1,clip=0,momentum=0.9)";
2555        let data = b"rwkv full-adam export/load deterministic continuation sample";
2556
2557        let mut c1 = Compressor::new_from_method(method).unwrap();
2558        let _ = c1.compress(data, CoderType::AC).unwrap();
2559
2560        let model_path = temp_path("rwkv_full_adam_export", "safetensors");
2561        let opt_path = optimizer_sidecar_path(&model_path);
2562        c1.export_online(&model_path).unwrap();
2563        assert!(
2564            opt_path.exists(),
2565            "expected optimizer sidecar to be exported"
2566        );
2567        let out1_after_export = c1.compress(data, CoderType::AC).unwrap();
2568
2569        let mut c2 = Compressor::new(&model_path).unwrap();
2570        let out2 = c2.compress(data, CoderType::AC).unwrap();
2571        assert_eq!(out1_after_export, out2);
2572
2573        std::fs::remove_file(&model_path).ok();
2574        std::fs::remove_file(model_path.with_extension("json")).ok();
2575        std::fs::remove_file(&opt_path).ok();
2576    }
2577
2578    #[test]
2579    fn test_online_export_reload_missing_full_adam_sidecar_fails() {
2580        let method = "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=41,train=adam,lr=0.0008,stride=1;policy:schedule=0..100:train(scope=all,opt=adam,lr=0.0008,stride=1,bptt=1,clip=0,momentum=0.9)";
2581        let mut c = Compressor::new_from_method(method).unwrap();
2582        let _ = c
2583            .compress(b"rwkv strict optimizer-sidecar requirement", CoderType::AC)
2584            .unwrap();
2585
2586        let model_path = temp_path("rwkv_full_adam_missing_sidecar", "safetensors");
2587        let opt_path = optimizer_sidecar_path(&model_path);
2588        c.export_online(&model_path).unwrap();
2589        std::fs::remove_file(&opt_path).unwrap();
2590
2591        let err = match Compressor::new(&model_path) {
2592            Ok(_) => panic!("expected missing optimizer sidecar to fail"),
2593            Err(err) => err,
2594        };
2595        assert!(format!("{err:#}").contains("missing optimizer sidecar"));
2596
2597        std::fs::remove_file(&model_path).ok();
2598        std::fs::remove_file(model_path.with_extension("json")).ok();
2599    }
2600
2601    #[test]
2602    fn test_clone_preserves_non_head_training_trace() {
2603        let method = "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=43,train=adam,lr=0.0008,stride=1;policy:schedule=0..100:train(scope=all,opt=adam,lr=0.0008,stride=1,bptt=1,clip=0,momentum=0.9)";
2604        let mut c = Compressor::new_from_method(method).unwrap();
2605        c.reset_and_prime();
2606        c.absorb_chain(&[b"clone trace prefix".as_slice()]).unwrap();
2607
2608        let mut cloned = c.clone();
2609        let score = cloned
2610            .cross_entropy_from_current(b"clone trace query")
2611            .unwrap();
2612        assert!(score.is_finite());
2613    }
2614}