infotheory/backends/mambazip/
mod.rs

1#![allow(clippy::items_after_test_module)]
2
3// mambazip - deterministic CPU-first Mamba-1 compressor/runtime.
4
5use anyhow::{Context, Result, bail};
6use serde_json::json;
7use std::fs;
8use std::io::{Cursor, Read, Write};
9use std::path::{Path, PathBuf};
10use std::sync::Arc;
11
12use crate::backends::llm_policy::{
13    self, LlmPolicy, OptimizerKind, PolicyAction, PolicyRuntime, split_method_policy_segments,
14};
15/// Mamba-1 model internals.
16pub mod mamba1;
17
18/// Shared entropy coders.
19pub use crate::coders;
20/// Backward-compatible coder re-export.
21pub use crate::coders::CoderType;
22
23use crate::coders::{
24    ANS_TOTAL, ArithmeticDecoder, ArithmeticEncoder, BlockedRansDecoder, BlockedRansEncoder,
25    CDF_TOTAL, Cdf, quantize_pdf_to_cdf_with_buffer, quantize_pdf_to_rans_cdf_with_buffer,
26};
27
28/// Mamba model config.
29pub use mamba1::Config;
30/// Mamba model.
31pub use mamba1::Model;
32/// Mamba reusable scratch buffers.
33pub use mamba1::ScratchBuffers;
34/// Mamba recurrent state.
35pub use mamba1::State;
36
/// File format magic for mambazip payloads.
pub const MAGIC: u32 = 0x5a424d4d; // "MMBZ"
/// File format version.
pub const VERSION: u8 = 1;
/// Byte vocabulary size.
pub const VOCAB_SIZE: usize = 256;
// Chunk size (in symbols) for replaying buffered TBPTT segments.
// NOTE(review): not referenced in this chunk — confirm usage in the trainer.
const TBPTT_REPLAY_CHUNK: usize = 32;
// Train-scope names accepted inside a `policy:` segment; passed to
// `llm_policy::parse_policy_segment` by `parse_method_spec`.
const MAMBA_TRAIN_SCOPES: &[&str] = &[
    "embed",
    "layer_norm",
    "mixer_conv",
    "mixer_ssm",
    "mixer_proj",
    "head",
    "bias",
    "all",
    "none",
];
55
#[inline]
/// Path of the optimizer-state sidecar that accompanies a model file.
///
/// Replaces the model file's final extension (if any) with
/// `opt.safetensors`, e.g. `model.safetensors` -> `model.opt.safetensors`.
fn optimizer_sidecar_path(model_path: &Path) -> PathBuf {
    let mut sidecar = model_path.to_path_buf();
    sidecar.set_extension("opt.safetensors");
    sidecar
}
60
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
/// Online adaptation mode for Mamba output-bias updates.
///
/// Selected via the `train=` cfg key; accepted tokens are
/// `none`/`sgd`/`adam` or `0`/`1`/`2` (see `parse_train_mode_token`).
pub enum OnlineTrainMode {
    /// Disable updates.
    None,
    /// SGD updates.
    Sgd,
    /// Adam updates.
    Adam,
}
71
#[derive(Clone, Debug)]
/// Online/runtime model construction config.
///
/// Field names mirror the `cfg:` method-string keys accepted by
/// [`parse_method_spec`]; dimensions are clamped to minimums by
/// [`OnlineConfig::to_mamba_config`].
pub struct OnlineConfig {
    /// Hidden width.
    pub hidden: usize,
    /// Number of layers.
    pub layers: usize,
    /// Mamba inner width (`d_inner`).
    pub intermediate: usize,
    /// Mamba state width (`d_state`).
    pub state: usize,
    /// Mamba depthwise conv width (`d_conv`).
    pub conv: usize,
    /// Delta rank (`dt_rank`).
    pub dt_rank: usize,
    /// RNG seed for random init.
    pub seed: u64,
    /// Online training mode.
    pub train_mode: OnlineTrainMode,
    /// Learning rate.
    pub lr: f32,
    /// Update stride.
    pub stride: usize,
}
96
97impl Default for OnlineConfig {
98    fn default() -> Self {
99        Self {
100            hidden: 256,
101            layers: 6,
102            intermediate: 512,
103            state: 16,
104            conv: 4,
105            dt_rank: 16,
106            seed: 0,
107            train_mode: OnlineTrainMode::None,
108            lr: 0.001,
109            stride: 1,
110        }
111    }
112}
113
114impl OnlineConfig {
115    /// Convert to validated model config.
116    pub fn to_mamba_config(&self) -> Result<Config> {
117        let cfg = Config {
118            vocab_size: VOCAB_SIZE,
119            hidden_size: self.hidden.max(16),
120            num_layers: self.layers.max(1),
121            inner_size: self.intermediate.max(16),
122            state_size: self.state.max(1),
123            conv_kernel: self.conv.max(1),
124            dt_rank: self.dt_rank.max(1),
125            layer_norm_eps: 1e-5,
126        };
127        cfg.validate()?;
128        Ok(cfg)
129    }
130}
131
#[derive(Clone, Debug)]
/// Parsed method specification.
///
/// Produced by [`parse_method_spec`]; selects between loading pre-trained
/// weights from disk and constructing a randomly initialized online model.
pub enum MethodSpec {
    /// Load model from filesystem.
    File {
        /// Path to `.safetensors` model weights.
        path: PathBuf,
        /// Optional runtime training/inference policy.
        policy: Option<LlmPolicy>,
    },
    /// Construct random model + online adaptation config.
    Online {
        /// Online model/training configuration.
        cfg: OnlineConfig,
        /// Optional runtime training/inference policy.
        policy: Option<LlmPolicy>,
    },
}
150
#[derive(Clone)]
// Mutable bookkeeping for a model built from an online `cfg:` method:
// policy scheduling state plus optimizer buffers for output-bias,
// LM-head, and full-parameter training.
struct OnlineRuntime {
    // Construction/training configuration the model was built from.
    cfg: OnlineConfig,
    // Canonical `cfg:` method string (see `cfg_to_method_string`).
    canonical_method: String,
    // Optional parsed policy driving per-symbol train/infer actions.
    policy: Option<LlmPolicy>,
    // Compiled policy schedule; created lazily per stream
    // (see `prepare_policy_stream` / `next_policy_action`).
    policy_runtime: Option<PolicyRuntime>,
    // True when some policy train scope touches non-head parameters,
    // requiring a full TBPTT trace (see `policy_needs_full_trace`).
    needs_full_trace: bool,
    // Declared stream length in symbols, when known.
    policy_stream_total: Option<u64>,
    // Train actions executed so far in the current stream.
    policy_train_steps: u64,
    // Symbols pushed through the model — TODO confirm reset semantics.
    tokens_processed: u64,
    // Additive output bias over the byte vocabulary (len = vocab_size).
    out_bias: Vec<f32>,
    // Adam first/second moments for `out_bias`; allocated only when Adam
    // can be requested by the config or the policy.
    adam_m: Option<Vec<f32>>,
    adam_v: Option<Vec<f32>>,
    // Optimizer state for full-parameter training.
    full_adam: Option<mamba1::FullAdamState>,
    // Adam moments for the LM head (len = vocab_size * hidden_size).
    lm_head_adam_m: Option<Vec<f32>>,
    lm_head_adam_v: Option<Vec<f32>>,
    // Adam timestep counter — NOTE(review): presumably shared by bias and
    // LM-head updates; confirm in the update path.
    adam_t: usize,
    // Full-trace TBPTT buffers; present iff `needs_full_trace`.
    full_tbptt: Option<FullTbpttRuntime>,
}
170
#[derive(Clone, Copy, Debug)]
// Effective hyperparameters of a policy train action, captured so that a
// settings change between steps can be detected (see `matches`).
struct FullTrainSettings {
    optimizer: OptimizerKind,
    lr: f32,
    scope: mamba1::TrainScopeMask,
    // Truncated-BPTT window length.
    bptt: usize,
    // Gradient-clip threshold — NOTE(review): 0 appears to mean "disabled"
    // in the policy strings used by the tests; confirm in the trainer.
    clip: f32,
}
179
180impl FullTrainSettings {
181    fn matches(
182        self,
183        optimizer: OptimizerKind,
184        lr: f32,
185        scope: mamba1::TrainScopeMask,
186        bptt: usize,
187        clip: f32,
188    ) -> bool {
189        self.optimizer == optimizer
190            && self.lr.to_bits() == lr.to_bits()
191            && self.scope.embed == scope.embed
192            && self.scope.layer_norm == scope.layer_norm
193            && self.scope.mixer_conv == scope.mixer_conv
194            && self.scope.mixer_ssm == scope.mixer_ssm
195            && self.scope.mixer_proj == scope.mixer_proj
196            && self.scope.head == scope.head
197            && self.scope.bias == scope.bias
198            && self.bptt == bptt
199            && self.clip.to_bits() == clip.to_bits()
200    }
201}
202
#[derive(Clone)]
// One recorded step of a TBPTT segment: the token fed in, the symbol that
// was the prediction target, and the probability distribution emitted for
// it — NOTE(review): `pdf` length presumably equals VOCAB_SIZE; confirm at
// the recording site.
struct FullTbpttStep {
    input_token: u32,
    target_symbol: u8,
    pdf: Vec<f64>,
}
209
#[derive(Clone)]
// Bookkeeping for full-parameter truncated-BPTT training.
struct FullTbpttRuntime {
    // Token whose prediction has not yet seen its target (the "pending
    // predictive edge" preserved across stream resets).
    pending_input_token: Option<u32>,
    // Recurrent state snapshot taken before consuming the pending token.
    pending_input_pre_state: Option<State>,
    // State at the start of the current segment (replay anchor).
    segment_start_state: Option<State>,
    // Steps buffered for the current segment.
    steps: Vec<FullTbpttStep>,
    // Settings the buffered segment was recorded under; compared via
    // `FullTrainSettings::matches` to detect mid-segment changes.
    settings: Option<FullTrainSettings>,
}
218
#[derive(Clone)]
/// Snapshot of mutable runtime state.
///
/// Captures everything needed to restore a compressor mid-stream: the
/// weights handle, scratch buffers, recurrent state, PDF buffer, and any
/// online-training bookkeeping.
pub struct RuntimeSnapshot {
    // Shared weights handle (cloning bumps the Arc refcount, not the data).
    model: Arc<Model>,
    // Reusable forward-pass scratch.
    scratch: ScratchBuffers,
    // Recurrent model state.
    state: State,
    // Reusable probability-distribution buffer.
    pdf_buffer: Vec<f64>,
    // Online adaptation state, when the model was built from `cfg:`.
    online: Option<OnlineRuntime>,
}
228
229impl OnlineRuntime {
230    fn new(
231        cfg: OnlineConfig,
232        canonical_method: String,
233        policy: Option<LlmPolicy>,
234        vocab_size: usize,
235        hidden_size: usize,
236    ) -> Self {
237        let mut use_adam = matches!(cfg.train_mode, OnlineTrainMode::Adam);
238        if let Some(pol) = &policy {
239            use_adam = policy_uses_adam(pol) || use_adam;
240        }
241        let needs_full_trace = policy
242            .as_ref()
243            .map(policy_needs_full_trace)
244            .unwrap_or(false);
245        Self {
246            canonical_method,
247            cfg,
248            policy,
249            policy_runtime: None,
250            needs_full_trace,
251            policy_stream_total: None,
252            policy_train_steps: 0,
253            tokens_processed: 0,
254            out_bias: vec![0.0; vocab_size],
255            adam_m: use_adam.then(|| vec![0.0; vocab_size]),
256            adam_v: use_adam.then(|| vec![0.0; vocab_size]),
257            full_adam: None,
258            lm_head_adam_m: use_adam.then(|| vec![0.0; vocab_size * hidden_size]),
259            lm_head_adam_v: use_adam.then(|| vec![0.0; vocab_size * hidden_size]),
260            adam_t: 0,
261            full_tbptt: needs_full_trace.then(|| FullTbpttRuntime {
262                pending_input_token: None,
263                pending_input_pre_state: None,
264                segment_start_state: None,
265                steps: Vec::new(),
266                settings: None,
267            }),
268        }
269    }
270
271    fn prepare_policy_stream(&mut self, total_symbols: Option<u64>) -> Result<()> {
272        self.policy_stream_total = total_symbols;
273        self.policy_train_steps = 0;
274        if let Some(tbptt) = self.full_tbptt.as_mut() {
275            // Preserve pending predictive edge while resetting segment bookkeeping.
276            tbptt.segment_start_state = None;
277            tbptt.steps.clear();
278            tbptt.settings = None;
279        }
280        self.policy_runtime = match &self.policy {
281            Some(p) => Some(PolicyRuntime::new(p.compile(total_symbols)?)),
282            None => None,
283        };
284        Ok(())
285    }
286
287    #[inline]
288    fn next_policy_action(&mut self) -> Result<Option<PolicyAction>> {
289        if self.policy.is_none() {
290            return Ok(None);
291        }
292        if self.policy_runtime.is_none() {
293            self.prepare_policy_stream(None)?;
294        }
295        Ok(self.policy_runtime.as_mut().map(PolicyRuntime::next_action))
296    }
297}
298
299fn policy_uses_adam(policy: &LlmPolicy) -> bool {
300    use llm_policy::ScheduleRule;
301    for rule in &policy.schedule {
302        match rule {
303            ScheduleRule::Interval(interval) => {
304                if let PolicyAction::Train(train) = &interval.action
305                    && matches!(train.optimizer, OptimizerKind::Adam)
306                {
307                    return true;
308                }
309            }
310            ScheduleRule::Repeat(repeat) => {
311                for seg in &repeat.pattern {
312                    if let PolicyAction::Train(train) = &seg.action
313                        && matches!(train.optimizer, OptimizerKind::Adam)
314                    {
315                        return true;
316                    }
317                }
318            }
319        }
320    }
321    false
322}
323
324fn scope_needs_full_trace(scope: &llm_policy::TrainScopeSet) -> bool {
325    scope.all
326        || scope.contains("embed")
327        || scope.contains("layer_norm")
328        || scope.contains("mixer_conv")
329        || scope.contains("mixer_ssm")
330        || scope.contains("mixer_proj")
331}
332
333fn policy_needs_full_trace(policy: &LlmPolicy) -> bool {
334    use llm_policy::ScheduleRule;
335    for rule in &policy.schedule {
336        match rule {
337            ScheduleRule::Interval(interval) => {
338                if let PolicyAction::Train(train) = &interval.action
339                    && scope_needs_full_trace(&train.scope)
340                {
341                    return true;
342                }
343            }
344            ScheduleRule::Repeat(repeat) => {
345                for seg in &repeat.pattern {
346                    if let PolicyAction::Train(train) = &seg.action
347                        && scope_needs_full_trace(&train.scope)
348                    {
349                        return true;
350                    }
351                }
352            }
353        }
354    }
355    false
356}
357
358fn cfg_to_method_string(cfg: &OnlineConfig) -> String {
359    let train = match cfg.train_mode {
360        OnlineTrainMode::None => "none",
361        OnlineTrainMode::Sgd => "sgd",
362        OnlineTrainMode::Adam => "adam",
363    };
364    format!(
365        "cfg:hidden={},layers={},intermediate={},state={},conv={},dt_rank={},seed={},train={},lr={},stride={}",
366        cfg.hidden,
367        cfg.layers,
368        cfg.intermediate,
369        cfg.state,
370        cfg.conv,
371        cfg.dt_rank,
372        cfg.seed,
373        train,
374        cfg.lr,
375        cfg.stride.max(1),
376    )
377}
378
/// Fill `pdf_out` with softmax(logits [+ bias]) clamped to a 1e-12 floor
/// and renormalized.
///
/// Uses the max-shift trick for numerically stable exponentiation. A
/// non-finite or zero partition sum falls back to a uniform scale, so the
/// output always sums to ~1. Empty input is a no-op. Accumulation stays in
/// index order so results are deterministic.
fn softmax_pdf_floor_with_bias(logits: &[f32], bias: Option<&[f32]>, pdf_out: &mut [f64]) {
    debug_assert_eq!(logits.len(), pdf_out.len());
    if let Some(b) = bias {
        debug_assert_eq!(b.len(), logits.len());
    }
    let n = logits.len();
    if n == 0 {
        return;
    }

    // Effective logit at index i, with the optional bias applied in f32.
    let z = |i: usize| -> f32 {
        match bias {
            Some(b) => logits[i] + b[i],
            None => logits[i],
        }
    };

    // Max shift for stability.
    let mut max_logit = f32::NEG_INFINITY;
    for i in 0..n {
        let v = z(i);
        if v > max_logit {
            max_logit = v;
        }
    }

    // Unnormalized exponentials and their sequential running sum.
    let mut sum = 0.0f64;
    for i in 0..n {
        let p = f64::from(z(i) - max_logit).exp();
        pdf_out[i] = p;
        sum += p;
    }

    let uniform = 1.0 / (n as f64);
    let inv_sum = if sum.is_finite() && sum > 0.0 {
        1.0 / sum
    } else {
        uniform
    };

    // Normalize, floor, then renormalize the floored mass.
    const FLOOR: f64 = 1e-12;
    let mut norm = 0.0f64;
    for p in pdf_out.iter_mut() {
        *p = (*p * inv_sum).max(FLOOR);
        norm += *p;
    }
    let inv_norm = if norm.is_finite() && norm > 0.0 {
        1.0 / norm
    } else {
        uniform
    };
    for p in pdf_out.iter_mut() {
        *p *= inv_norm;
    }
}
440
441fn parse_u64(v: &str, key: &str) -> Result<u64> {
442    v.parse::<u64>()
443        .with_context(|| format!("invalid integer value for '{key}': {v}"))
444}
445
446fn parse_usize(v: &str, key: &str) -> Result<usize> {
447    v.parse::<usize>()
448        .with_context(|| format!("invalid integer value for '{key}': {v}"))
449}
450
451fn parse_f32(v: &str, key: &str) -> Result<f32> {
452    v.parse::<f32>()
453        .with_context(|| format!("invalid float value for '{key}': {v}"))
454}
455
456fn parse_train_mode_token(v: &str) -> Result<OnlineTrainMode> {
457    let code = v.trim().to_ascii_lowercase();
458    match code.as_str() {
459        "0" | "none" | "off" => Ok(OnlineTrainMode::None),
460        "1" | "sgd" => Ok(OnlineTrainMode::Sgd),
461        "2" | "adam" => Ok(OnlineTrainMode::Adam),
462        other => bail!("unknown train mode '{other}'"),
463    }
464}
465
466fn parse_cfg_positional(csv: &str) -> Result<OnlineConfig> {
467    let vals: Vec<&str> = csv
468        .split(',')
469        .map(|s| s.trim())
470        .filter(|s| !s.is_empty())
471        .collect();
472    if vals.len() != 6 && vals.len() != 7 {
473        bail!(
474            "positional cfg format expects 6 or 7 values: hidden,intermediate,layers,train,seed,lr[,stride]"
475        );
476    }
477
478    Ok(OnlineConfig {
479        hidden: parse_usize(vals[0], "hidden")?,
480        intermediate: parse_usize(vals[1], "intermediate")?,
481        layers: parse_usize(vals[2], "layers")?,
482        train_mode: parse_train_mode_token(vals[3])?,
483        seed: parse_u64(vals[4], "seed")?,
484        lr: parse_f32(vals[5], "lr")?,
485        stride: if vals.len() == 7 {
486            parse_usize(vals[6], "stride")?
487        } else {
488            1
489        },
490        ..OnlineConfig::default()
491    })
492}
493
/// Parse a method string.
///
/// Supported formats:
/// - `file:/path/to/model.safetensors`
/// - `file:/path/to/model.safetensors;policy:...`
/// - `cfg:key=value,...[;policy:...]`
/// - positional `cfg` CSV
/// - existing model path
///
/// Formats are tried in the order above; a bare string is treated as a
/// model path only if it exists on disk, otherwise a comma makes it a
/// positional cfg CSV.
///
/// # Errors
///
/// Fails when the policy segment is malformed, a `file:` path is empty,
/// a policy `load_from` conflicts with an explicit weights path, a cfg
/// key/value is unknown or unparsable, or nothing matches.
pub fn parse_method_spec(method: &str) -> Result<MethodSpec> {
    // Split off an optional trailing `;policy:...` segment first.
    let (base, policy_segment) = split_method_policy_segments(method)?;
    let parse_policy = |s: &str| llm_policy::parse_policy_segment(s, MAMBA_TRAIN_SCOPES);
    let policy = policy_segment
        .as_deref()
        .map(parse_policy)
        .transpose()
        .context("failed to parse mamba policy segment")?;

    // Explicit `file:` prefix — load weights from the given path.
    if let Some(path) = base.strip_prefix("file:") {
        let p = PathBuf::from(path.trim());
        if p.as_os_str().is_empty() {
            bail!("empty file path in mamba method");
        }
        // A policy-level `load_from` would conflict with the explicit path.
        if policy.as_ref().and_then(|p| p.load_from.as_ref()).is_some() {
            bail!("mamba method cannot use policy load_from together with file:<path>");
        }
        return Ok(MethodSpec::File { path: p, policy });
    }

    // `cfg:` prefix — positional CSV when there is no '=', else key=value.
    if let Some(cfg_s) = base.strip_prefix("cfg:") {
        if !cfg_s.contains('=') {
            return Ok(MethodSpec::Online {
                cfg: parse_cfg_positional(cfg_s)?,
                policy,
            });
        }
        let mut cfg = OnlineConfig::default();
        for pair in cfg_s.split(',') {
            let pair = pair.trim();
            if pair.is_empty() {
                continue;
            }
            let (k, v) = pair
                .split_once('=')
                .with_context(|| format!("invalid cfg key/value pair '{pair}'"))?;
            let key = k.trim().to_ascii_lowercase();
            let val = v.trim();
            match key.as_str() {
                "hidden" => cfg.hidden = parse_usize(val, "hidden")?,
                "layers" => cfg.layers = parse_usize(val, "layers")?,
                "intermediate" => cfg.intermediate = parse_usize(val, "intermediate")?,
                "state" | "d_state" => cfg.state = parse_usize(val, "state")?,
                "conv" | "d_conv" => cfg.conv = parse_usize(val, "conv")?,
                "dt_rank" => cfg.dt_rank = parse_usize(val, "dt_rank")?,
                "seed" => cfg.seed = parse_u64(val, "seed")?,
                "lr" => cfg.lr = parse_f32(val, "lr")?,
                "stride" => cfg.stride = parse_usize(val, "stride")?,
                "train" | "train_mode" => cfg.train_mode = parse_train_mode_token(val)?,
                other => bail!("unknown mamba cfg key '{other}'"),
            }
        }
        return Ok(MethodSpec::Online { cfg, policy });
    }

    // Bare string naming an existing file — treat it as a model path.
    let plain = PathBuf::from(base.trim());
    if plain.exists() {
        if policy.as_ref().and_then(|p| p.load_from.as_ref()).is_some() {
            bail!("mamba method cannot use policy load_from together with file path");
        }
        return Ok(MethodSpec::File {
            path: plain,
            policy,
        });
    }

    // Last resort: a comma implies the positional cfg CSV form.
    if base.contains(',') {
        return Ok(MethodSpec::Online {
            cfg: parse_cfg_positional(&base)?,
            policy,
        });
    }

    bail!(
        "mamba method must be 'file:<path>', 'cfg:<k=v,...>', positional cfg CSV, or an existing model path"
    );
}
579
/// Framing header for mambazip streams.
///
/// Serialized layout (little-endian, [`Header::SIZE`] bytes): `magic: u32`,
/// `version: u8`, `coder: u8`, `original_len: u64`, `crc32: u32`.
#[derive(Debug, Clone)]
pub struct Header {
    /// Magic number.
    pub magic: u32,
    /// Version byte.
    pub version: u8,
    /// Coder type (0=AC,1=rANS).
    pub coder: u8,
    /// Original length.
    pub original_len: u64,
    /// CRC32 checksum of original data.
    pub crc32: u32,
}
594
595impl Header {
596    /// Header serialized size.
597    pub const SIZE: usize = 4 + 1 + 1 + 8 + 4;
598
599    /// Construct a new header.
600    pub fn new(coder: CoderType, original_len: u64, crc32: u32) -> Self {
601        Self {
602            magic: MAGIC,
603            version: VERSION,
604            coder: match coder {
605                CoderType::AC => 0,
606                CoderType::RANS => 1,
607            },
608            original_len,
609            crc32,
610        }
611    }
612
613    /// Write header.
614    pub fn write<W: Write>(&self, w: &mut W) -> Result<()> {
615        w.write_all(&self.magic.to_le_bytes())?;
616        w.write_all(&[self.version])?;
617        w.write_all(&[self.coder])?;
618        w.write_all(&self.original_len.to_le_bytes())?;
619        w.write_all(&self.crc32.to_le_bytes())?;
620        Ok(())
621    }
622
623    /// Read header.
624    pub fn read<R: Read>(r: &mut R) -> Result<Self> {
625        let mut buf4 = [0u8; 4];
626        let mut buf8 = [0u8; 8];
627        let mut buf1 = [0u8; 1];
628
629        r.read_exact(&mut buf4)?;
630        let magic = u32::from_le_bytes(buf4);
631        if magic != MAGIC {
632            bail!(
633                "invalid magic number: expected 0x{:08X}, got 0x{:08X}",
634                MAGIC,
635                magic
636            );
637        }
638
639        r.read_exact(&mut buf1)?;
640        let version = buf1[0];
641        if version > VERSION {
642            bail!(
643                "unsupported version: {} (max supported: {})",
644                version,
645                VERSION
646            );
647        }
648
649        r.read_exact(&mut buf1)?;
650        let coder = buf1[0];
651
652        r.read_exact(&mut buf8)?;
653        let original_len = u64::from_le_bytes(buf8);
654
655        r.read_exact(&mut buf4)?;
656        let crc32 = u32::from_le_bytes(buf4);
657
658        Ok(Self {
659            magic,
660            version,
661            coder,
662            original_len,
663            crc32,
664        })
665    }
666
667    /// Decode coder selection.
668    pub fn coder_type(&self) -> CoderType {
669        match self.coder {
670            0 => CoderType::AC,
671            _ => CoderType::RANS,
672        }
673    }
674}
675
/// Compute CRC32 of data.
///
/// Thin delegation to the shared coder implementation so callers of this
/// backend need not import `crate::coders` directly.
pub fn crc32(data: &[u8]) -> u32 {
    crate::coders::crc32(data)
}
680
/// `Write` sink that discards all bytes and only tracks how many were
/// "written" — used to size an encoding without buffering it.
struct CountingWriter {
    n: u64,
}

impl CountingWriter {
    /// Fresh counter starting at zero.
    #[inline]
    fn new() -> Self {
        CountingWriter { n: 0 }
    }

    /// Total bytes accepted so far.
    #[inline]
    fn bytes_written(&self) -> u64 {
        self.n
    }
}

impl Write for CountingWriter {
    /// Accept the whole buffer, bumping the counter; saturating add so the
    /// total can never wrap.
    #[inline]
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let len = buf.len();
        self.n = self.n.saturating_add(len as u64);
        Ok(len)
    }

    /// Nothing is buffered, so flushing is a no-op.
    #[inline]
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}
710
/// Stateful compressor built around a Mamba model.
pub struct Compressor {
    /// Model weights.
    pub model: Arc<Model>,
    /// Recurrent state.
    pub state: State,
    /// Reusable scratch.
    pub scratch: ScratchBuffers,
    /// Reusable PDF buffer.
    pub pdf_buffer: Vec<f64>,
    // Scratch CDF buffer for the arithmetic coder — presumably fed to
    // `quantize_pdf_to_cdf_with_buffer`; confirm in the compress path.
    cdf_buffer_ac: Vec<u32>,
    // Scratch frequency accumulator for AC CDF quantization.
    ac_freq_buffer: Vec<i64>,
    // Scratch CDF buffer for the rANS coder — presumably fed to
    // `quantize_pdf_to_rans_cdf_with_buffer`; confirm in the compress path.
    cdf_buffer_rans: Vec<u32>,
    // Scratch frequency accumulator for rANS CDF quantization.
    rans_freq_buffer: Vec<i64>,
    // Online adaptation state, present when built from a `cfg:` method.
    online: Option<OnlineRuntime>,
    // Path the weights were loaded from, when constructed from a file.
    source_model_path: Option<PathBuf>,
}
728
729#[cfg(test)]
730mod tests {
731    use super::*;
732
733    fn temp_path(prefix: &str, ext: &str) -> PathBuf {
734        let now = std::time::SystemTime::now()
735            .duration_since(std::time::UNIX_EPOCH)
736            .unwrap_or_default()
737            .as_nanos();
738        std::env::temp_dir().join(format!("{prefix}_{}_{}.{}", std::process::id(), now, ext))
739    }
740
741    #[test]
742    fn parse_method_spec_accepts_cfg_and_positional() {
743        let named = parse_method_spec(
744            "cfg:hidden=64,layers=2,intermediate=96,state=8,conv=3,dt_rank=4,train=sgd,lr=0.01,stride=2;policy:schedule=0..100:infer",
745        )
746        .expect("named cfg");
747        match named {
748            MethodSpec::Online { cfg, .. } => {
749                assert_eq!(cfg.hidden, 64);
750                assert_eq!(cfg.layers, 2);
751                assert_eq!(cfg.intermediate, 96);
752                assert_eq!(cfg.state, 8);
753                assert_eq!(cfg.conv, 3);
754                assert_eq!(cfg.dt_rank, 4);
755                assert!(matches!(cfg.train_mode, OnlineTrainMode::Sgd));
756                assert_eq!(cfg.stride, 2);
757            }
758            _ => panic!("expected online cfg"),
759        }
760
761        let positional =
762            parse_method_spec("cfg:64,96,2,adam,123,0.001,3;policy:schedule=0..100:infer")
763                .expect("positional cfg");
764        match positional {
765            MethodSpec::Online { cfg, .. } => {
766                assert_eq!(cfg.hidden, 64);
767                assert_eq!(cfg.intermediate, 96);
768                assert_eq!(cfg.layers, 2);
769                assert!(matches!(cfg.train_mode, OnlineTrainMode::Adam));
770                assert_eq!(cfg.seed, 123);
771                assert_eq!(cfg.stride, 3);
772            }
773            _ => panic!("expected online cfg"),
774        }
775    }
776
777    #[test]
778    fn parse_method_spec_accepts_cfg_without_policy() {
779        let spec = parse_method_spec("cfg:hidden=64,layers=2,intermediate=96").expect("cfg");
780        match spec {
781            MethodSpec::Online { cfg, policy } => {
782                assert_eq!(cfg.hidden, 64);
783                assert_eq!(cfg.layers, 2);
784                assert_eq!(cfg.intermediate, 96);
785                assert!(policy.is_none());
786            }
787            _ => panic!("expected online cfg"),
788        }
789    }
790
791    #[test]
792    fn canonical_method_omits_policy_when_absent() {
793        let c = Compressor::new_from_method("cfg:hidden=64,layers=1,intermediate=96")
794            .expect("online model");
795        assert_eq!(
796            c.online_method_string(),
797            Some(
798                "cfg:hidden=64,layers=1,intermediate=96,state=16,conv=4,dt_rank=16,seed=0,train=none,lr=0.001,stride=1"
799            )
800        );
801    }
802
803    #[test]
804    fn export_reload_roundtrip_reproducible() {
805        let cfg = Config {
806            vocab_size: 256,
807            hidden_size: 32,
808            num_layers: 2,
809            inner_size: 48,
810            state_size: 8,
811            conv_kernel: 3,
812            dt_rank: 4,
813            layer_norm_eps: 1e-5,
814        };
815        let model = Arc::new(Model::new_random(cfg.clone(), 42).expect("random model"));
816        let mut c1 = Compressor::new_from_model(model);
817        c1.reset_and_prime();
818        let _ = c1.cross_entropy_from_current(b"mamba test").expect("score");
819
820        let base = std::env::temp_dir().join(format!(
821            "infotheory_mamba_rt_{}_{}.safetensors",
822            std::process::id(),
823            c1.tokens_processed()
824        ));
825        c1.export_online(&base).expect("export");
826
827        let mut c2 = Compressor::new(&base).expect("reload");
828        c2.reset_and_prime();
829        let h1 = c1.cross_entropy(b"abcabc").expect("h1");
830        let h2 = c2.cross_entropy(b"abcabc").expect("h2");
831        assert!((h1 - h2).abs() < 1e-9);
832
833        let _ = std::fs::remove_file(&base);
834        let _ = std::fs::remove_file(base.with_extension("json"));
835    }
836
837    #[test]
838    fn online_training_updates_lm_head_weights() {
839        let method = "cfg:hidden=64,layers=2,intermediate=96,state=8,conv=3,dt_rank=4,seed=11,train=sgd,lr=0.01,stride=1;policy:schedule=0..100:train(scope=head+bias,opt=sgd,lr=0.01,stride=1,bptt=1,clip=0,momentum=0.9)";
840        let mut c = Compressor::new_from_method(method).expect("online model");
841        c.reset_and_prime();
842        let before = c.model.lm_head_weights()[0..64].to_vec();
843        let _ = c
844            .cross_entropy_from_current(b"online mamba weight update")
845            .expect("score");
846        let after = &c.model.lm_head_weights()[0..64];
847        let mut changed = false;
848        for i in 0..before.len() {
849            if before[i].to_bits() != after[i].to_bits() {
850                changed = true;
851                break;
852            }
853        }
854        assert!(
855            changed,
856            "expected LM-head weights to change under online training"
857        );
858    }
859
860    #[test]
861    fn online_training_scope_all_updates_non_head_params() {
862        let method = "cfg:hidden=64,layers=1,intermediate=64,state=8,conv=3,dt_rank=4,seed=7,train=adam,lr=0.002,stride=1;policy:schedule=0..100:train(scope=mixer_proj,opt=adam,lr=0.002,stride=1,bptt=1,clip=0,momentum=0.9)";
863        let mut c = Compressor::new_from_method(method).expect("online model");
864        c.reset_and_prime();
865        let before_head = c.model.lm_head_weights()[0..64].to_vec();
866        let before_model = (*c.model).clone();
867        let _ = c
868            .cross_entropy_from_current(b"scope mixer_proj should train non-head mamba params")
869            .expect("score");
870        let after_head = &c.model.lm_head_weights()[0..64];
871        let mut head_unchanged = true;
872        for i in 0..before_head.len() {
873            if before_head[i].to_bits() != after_head[i].to_bits() {
874                head_unchanged = false;
875                break;
876            }
877        }
878        assert!(
879            head_unchanged,
880            "expected LM-head weights to remain unchanged under scope=mixer_proj"
881        );
882
883        // Compare logits from fresh state to detect non-head parameter movement.
884        let mut s1 = before_model.new_state();
885        let mut sc1 = ScratchBuffers::new(before_model.config());
886        let mut s2 = c.model.new_state();
887        let mut sc2 = ScratchBuffers::new(c.model.config());
888        let logits_before = before_model.forward(&mut sc1, 0, &mut s1);
889        let logits_after = c.model.forward(&mut sc2, 0, &mut s2);
890        let mut changed = false;
891        for idx in 0..logits_before.len().min(logits_after.len()) {
892            if logits_before[idx].to_bits() != logits_after[idx].to_bits() {
893                changed = true;
894                break;
895            }
896        }
897        assert!(
898            changed,
899            "expected non-head parameters to update under scope=mixer_proj"
900        );
901    }
902
903    #[test]
904    fn online_training_scope_all_bptt_gt_one_supported() {
905        let method = "cfg:hidden=64,layers=1,intermediate=64,state=8,conv=3,dt_rank=4,seed=7,train=adam,lr=0.002,stride=1;policy:schedule=0..100:train(scope=all,opt=adam,lr=0.002,stride=1,bptt=2,clip=0,momentum=0.9)";
906        let mut c = Compressor::new_from_method(method).expect("online model");
907        let before_path = temp_path("mamba_tbptt_before", "safetensors");
908        let after_path = temp_path("mamba_tbptt_after", "safetensors");
909        c.model.save_safetensors(&before_path).expect("save before");
910        c.reset_and_prime();
911        let score = c
912            .cross_entropy_from_current(b"abcdef")
913            .expect("tbptt score");
914        assert!(score.is_finite());
915        c.model.save_safetensors(&after_path).expect("save after");
916        let before = std::fs::read(&before_path).expect("read before");
917        let after = std::fs::read(&after_path).expect("read after");
918        assert_ne!(
919            before, after,
920            "expected tbptt full training to update params"
921        );
922        std::fs::remove_file(before_path).ok();
923        std::fs::remove_file(after_path).ok();
924    }
925
    #[test]
    fn online_training_full_tbptt_updates_first_symbol_after_priming() {
        // With bptt=8, a single-symbol stream never fills a segment; the
        // pending step must still be flushed at stream end so even the first
        // symbol after priming produces a parameter update.
        let method = "cfg:hidden=64,layers=1,intermediate=64,state=8,conv=3,dt_rank=4,seed=33,train=adam,lr=0.0008,stride=1;policy:schedule=0..100:train(scope=all,opt=adam,lr=0.0008,stride=1,bptt=8,clip=0,momentum=0.9)";
        let mut c = Compressor::new_from_method(method).expect("online model");
        let before_path = temp_path("mamba_first_symbol_before", "safetensors");
        let after_path = temp_path("mamba_first_symbol_after", "safetensors");
        c.model.save_safetensors(&before_path).expect("save before");

        c.reset_and_prime();
        let score = c
            .cross_entropy_from_current(b"a")
            .expect("single-symbol score");
        assert!(score.is_finite());
        c.model.save_safetensors(&after_path).expect("save after");

        // Byte-level comparison of the serialized models detects the update.
        let before = std::fs::read(&before_path).expect("read before");
        let after = std::fs::read(&after_path).expect("read after");
        assert_ne!(
            before, after,
            "expected first symbol update to flush at stream end"
        );
        std::fs::remove_file(before_path).ok();
        std::fs::remove_file(after_path).ok();
    }
950
    #[test]
    fn export_reload_roundtrip_preserves_full_adam_resume() {
        // After a training pass, exporting must write the optimizer sidecar,
        // and a compressor reloaded from disk must produce bit-identical
        // output to the original continuing in memory.
        let method = "cfg:hidden=64,layers=1,intermediate=64,state=8,conv=3,dt_rank=4,seed=17,train=adam,lr=0.0008,stride=1;policy:schedule=0..100:train(scope=all,opt=adam,lr=0.0008,stride=1,bptt=1,clip=0,momentum=0.9)";
        let data = b"mamba full adam export/reload deterministic continuation";
        let mut c1 = Compressor::new_from_method(method).expect("online model");
        // First pass warms up the full-Adam state before export.
        let _ = c1.compress(data, CoderType::AC).expect("pre-train pass");

        let model_path = std::env::temp_dir().join(format!(
            "infotheory_mamba_full_adam_{}_{}.safetensors",
            std::process::id(),
            c1.tokens_processed()
        ));
        c1.export_online(&model_path).expect("export");
        // Full-Adam moments live in the `.opt.safetensors` sidecar.
        assert!(model_path.with_extension("opt.safetensors").exists());

        let out1 = c1
            .compress(data, CoderType::AC)
            .expect("post-export compress");
        let mut c2 = Compressor::new(&model_path).expect("reload");
        let out2 = c2.compress(data, CoderType::AC).expect("reload compress");
        assert_eq!(out1, out2, "full-adam resume must be bit-identical");

        let _ = std::fs::remove_file(&model_path);
        let _ = std::fs::remove_file(model_path.with_extension("json"));
        let _ = std::fs::remove_file(model_path.with_extension("opt.safetensors"));
    }
977
978    #[test]
979    fn clone_keeps_full_training_trace_mode() {
980        let method = "cfg:hidden=64,layers=1,intermediate=64,state=8,conv=3,dt_rank=4,seed=18,train=adam,lr=0.0008,stride=1;policy:schedule=0..100:train(scope=all,opt=adam,lr=0.0008,stride=1,bptt=1,clip=0,momentum=0.9)";
981        let mut c = Compressor::new_from_method(method).expect("online model");
982        let mut cloned = c.clone();
983        cloned.reset_and_prime();
984        let _ = cloned
985            .cross_entropy_from_current(b"clone must preserve training-trace mode")
986            .expect("full-training step should succeed after clone");
987        c.reset_and_prime();
988        let _ = c
989            .cross_entropy_from_current(b"baseline run")
990            .expect("baseline full-training step");
991    }
992
    #[test]
    fn runtime_snapshot_restores_non_head_training_state() {
        // Snapshot/restore must capture everything full training mutates:
        // restoring the same snapshot twice and re-running the same query
        // must yield identical scores.
        let method = "cfg:hidden=64,layers=1,intermediate=64,state=8,conv=3,dt_rank=4,seed=19,train=adam,lr=0.0008,stride=1;policy:schedule=0..100:train(scope=all,opt=adam,lr=0.0008,stride=1,bptt=1,clip=0,momentum=0.9)";
        let mut c = Compressor::new_from_method(method).expect("online model");
        c.reset_and_prime();
        c.absorb_chain(&[b"prior context".as_slice()])
            .expect("prefix");
        let snap = c.snapshot_runtime();

        // Deliberately mutate model/optimizer state past the snapshot point.
        let _ = c
            .cross_entropy_from_current(b"mutate model before restore")
            .expect("mutation pass");

        c.restore_runtime(&snap);
        let score_a = c
            .cross_entropy_from_current(b"query after restore")
            .expect("score a");

        c.restore_runtime(&snap);
        let score_b = c
            .cross_entropy_from_current(b"query after restore")
            .expect("score b");

        assert!((score_a - score_b).abs() < 1e-12);
    }
1018
1019    #[test]
1020    fn clone_preserves_non_head_training_trace() {
1021        let method = "cfg:hidden=64,layers=1,intermediate=64,state=8,conv=3,dt_rank=4,seed=20,train=adam,lr=0.0008,stride=1;policy:schedule=0..100:train(scope=all,opt=adam,lr=0.0008,stride=1,bptt=1,clip=0,momentum=0.9)";
1022        let mut c = Compressor::new_from_method(method).expect("online model");
1023        c.reset_and_prime();
1024        c.absorb_chain(&[b"clone trace prefix".as_slice()])
1025            .expect("prefix");
1026
1027        let mut cloned = c.clone();
1028        let score = cloned
1029            .cross_entropy_from_current(b"clone trace query")
1030            .expect("cloned full-training step");
1031        assert!(score.is_finite());
1032    }
1033}
1034
1035impl Clone for Compressor {
1036    fn clone(&self) -> Self {
1037        let mut cloned = Self::new_from_model(self.model.clone());
1038        cloned.state = self.state.clone();
1039        cloned.pdf_buffer.clone_from(&self.pdf_buffer);
1040        cloned.cdf_buffer_ac.clone_from(&self.cdf_buffer_ac);
1041        cloned.ac_freq_buffer.clone_from(&self.ac_freq_buffer);
1042        cloned.cdf_buffer_rans.clone_from(&self.cdf_buffer_rans);
1043        cloned.rans_freq_buffer.clone_from(&self.rans_freq_buffer);
1044        cloned.scratch = self.scratch.clone();
1045        cloned.online = self.online.clone();
1046        cloned.source_model_path = self.source_model_path.clone();
1047        cloned
1048    }
1049}
1050
1051impl Compressor {
1052    /// Create compressor by loading model path.
1053    pub fn new<P: AsRef<Path>>(model_path: P) -> Result<Self> {
1054        let model_path = model_path.as_ref();
1055        let model = Arc::new(Model::load(model_path)?);
1056        let mut c = Self::new_from_model(model);
1057        c.source_model_path = Some(model_path.to_path_buf());
1058        c.maybe_load_sidecar()?;
1059        Ok(c)
1060    }
1061
1062    /// Load model from path and wrap in Arc.
1063    pub fn load_model<P: AsRef<Path>>(model_path: P) -> Result<Arc<Model>> {
1064        Ok(Arc::new(Model::load(model_path)?))
1065    }
1066
1067    /// Create compressor from preloaded model.
1068    pub fn new_from_model(model: Arc<Model>) -> Self {
1069        let state = model.new_state();
1070        let vocab_size = model.config().vocab_size;
1071        let scratch = ScratchBuffers::new(model.config());
1072        Self {
1073            model,
1074            state,
1075            scratch,
1076            pdf_buffer: vec![0.0; vocab_size],
1077            cdf_buffer_ac: vec![0u32; vocab_size + 1],
1078            ac_freq_buffer: vec![0i64; vocab_size],
1079            cdf_buffer_rans: vec![0u32; vocab_size + 1],
1080            rans_freq_buffer: vec![0i64; vocab_size],
1081            online: None,
1082            source_model_path: None,
1083        }
1084    }
1085
    /// Create compressor from method string.
    ///
    /// Two spec shapes are handled (see `parse_method_spec`):
    /// - `File`: load a saved model, optionally attaching a training policy;
    /// - `Online`: build a model from an inline cfg, optionally warm-started
    ///   from `policy.load_from`.
    pub fn new_from_method(method: &str) -> Result<Self> {
        match parse_method_spec(method)? {
            MethodSpec::File { path, policy } => {
                let mut c = Self::new(&path)?;
                if let Some(policy) = policy {
                    // Canonicalize so the stored method string round-trips.
                    let canonical_method =
                        format!("file:{};policy:{}", path.display(), policy.canonical());
                    let hidden = c.model.config().hidden_size;
                    // Reuse a sidecar-loaded online runtime when present;
                    // otherwise start from default online settings.
                    let mut online = c.online.take().unwrap_or_else(|| {
                        OnlineRuntime::new(
                            OnlineConfig::default(),
                            canonical_method.clone(),
                            Some(policy.clone()),
                            VOCAB_SIZE,
                            hidden,
                        )
                    });
                    online.canonical_method = canonical_method;
                    online.policy = Some(policy);
                    // Full-parameter training requires forward traces.
                    online.needs_full_trace = online
                        .policy
                        .as_ref()
                        .map(policy_needs_full_trace)
                        .unwrap_or(false);
                    // Only allocate the TBPTT bookkeeping when traces are on.
                    online.full_tbptt = online.needs_full_trace.then(|| FullTbpttRuntime {
                        pending_input_token: None,
                        pending_input_pre_state: None,
                        segment_start_state: None,
                        steps: Vec::new(),
                        settings: None,
                    });
                    c.online = Some(online);
                    c.scratch.set_capture_train_trace(
                        c.online.as_ref().is_some_and(|o| o.needs_full_trace),
                    );
                }
                Ok(c)
            }
            MethodSpec::Online { cfg, policy } => {
                let mcfg = cfg.to_mamba_config()?;
                let model = if let Some(load_from) =
                    policy.as_ref().and_then(|p| p.load_from.as_ref())
                {
                    // Warm-start from a saved model, but only if its shape
                    // matches the requested cfg on every dimension.
                    let loaded = Arc::new(Model::load(load_from)?);
                    let loaded_cfg = loaded.config();
                    let shape_ok = loaded_cfg.vocab_size == mcfg.vocab_size
                        && loaded_cfg.hidden_size == mcfg.hidden_size
                        && loaded_cfg.num_layers == mcfg.num_layers
                        && loaded_cfg.inner_size == mcfg.inner_size
                        && loaded_cfg.state_size == mcfg.state_size
                        && loaded_cfg.conv_kernel == mcfg.conv_kernel
                        && loaded_cfg.dt_rank == mcfg.dt_rank;
                    if !shape_ok {
                        bail!(
                            "mamba policy load_from shape mismatch with cfg (strict match required)"
                        );
                    }
                    loaded
                } else {
                    // Deterministic initialization keyed by the cfg seed.
                    Arc::new(Model::new_random(mcfg, cfg.seed)?)
                };
                let mut c = Self::new_from_model(model);
                let mut canonical_method = cfg_to_method_string(&cfg);
                if let Some(policy) = policy.as_ref() {
                    canonical_method.push_str(";policy:");
                    canonical_method.push_str(&policy.canonical());
                }
                c.online = Some(OnlineRuntime::new(
                    cfg,
                    canonical_method,
                    policy,
                    VOCAB_SIZE,
                    c.model.config().hidden_size,
                ));
                c.scratch
                    .set_capture_train_trace(c.online.as_ref().is_some_and(|o| o.needs_full_trace));
                Ok(c)
            }
        }
    }
1167
1168    /// Reset state.
1169    pub fn reset(&mut self) {
1170        self.state.reset();
1171        self.clear_online_training_buffers();
1172    }
1173
1174    fn prepare_policy_stream(&mut self, total_symbols: Option<u64>) -> Result<()> {
1175        if let Some(online) = self.online.as_mut() {
1176            online.prepare_policy_stream(total_symbols)?;
1177        }
1178        Ok(())
1179    }
1180
1181    fn clear_online_training_buffers(&mut self) {
1182        if let Some(online) = self.online.as_mut()
1183            && let Some(tbptt) = online.full_tbptt.as_mut()
1184        {
1185            tbptt.pending_input_token = None;
1186            tbptt.pending_input_pre_state = None;
1187            tbptt.segment_start_state = None;
1188            tbptt.steps.clear();
1189            tbptt.settings = None;
1190        }
1191    }
1192
1193    fn forward_with_online_record(&mut self, token: u32) {
1194        if let Some(online) = self.online.as_mut()
1195            && let Some(tbptt) = online.full_tbptt.as_mut()
1196        {
1197            tbptt.pending_input_token = Some(token);
1198            tbptt.pending_input_pre_state = Some(self.state.clone());
1199        }
1200        let _ = self
1201            .model
1202            .forward(&mut self.scratch, token, &mut self.state);
1203    }
1204
    /// Apply one deferred full-parameter TBPTT pass over the queued steps.
    ///
    /// No-op when there is no online runtime or the queue is empty. On
    /// success the predictive `pdf_buffer` is re-derived from the logits the
    /// replay leaves in scratch.
    fn flush_full_tbptt_segment(&mut self) -> Result<()> {
        // Phase 1: extract the queued segment while holding the `online`
        // borrow, so phase 2 can take a mutable borrow of the model.
        let extracted = {
            match self.online.as_mut() {
                Some(online) => match online.full_tbptt.as_mut() {
                    Some(tbptt) if !tbptt.steps.is_empty() => {
                        let settings = tbptt.settings.take().ok_or_else(|| {
                            anyhow::anyhow!("mamba full tbptt settings are missing")
                        })?;
                        let start_state = tbptt.segment_start_state.take().ok_or_else(|| {
                            anyhow::anyhow!("mamba full tbptt segment start is missing")
                        })?;
                        let steps = std::mem::take(&mut tbptt.steps);
                        // Adam over model params needs per-parameter moment
                        // state; allocate it lazily on first use.
                        let need_full_adam = matches!(settings.optimizer, OptimizerKind::Adam)
                            && settings.scope.trains_model_params()
                            && online.full_adam.is_none();
                        Some((settings, start_state, steps, need_full_adam))
                    }
                    _ => None,
                },
                None => None,
            }
        };
        let Some((settings, start_state, steps, need_full_adam)) = extracted else {
            return Ok(());
        };

        if need_full_adam {
            let full_adam = self.model.new_full_adam_state();
            if let Some(online) = self.online.as_mut() {
                online.full_adam = Some(full_adam);
            }
        }

        let segment_steps = steps
            .into_iter()
            .map(|step| (step.input_token, step.target_symbol, step.pdf))
            .collect::<Vec<_>>();
        // Copy-on-write: detach the model if anything else still shares it.
        let model = Arc::make_mut(&mut self.model);
        let Some(online) = self.online.as_mut() else {
            return Ok(());
        };
        model.online_train_segment_tbptt(
            &mut self.scratch,
            &start_state,
            &segment_steps,
            settings.scope,
            settings.optimizer,
            settings.lr,
            settings.clip,
            TBPTT_REPLAY_CHUNK,
            &mut online.adam_t,
            online.full_adam.as_mut(),
            // Bias buffers are handed over only when the scope trains them.
            if settings.scope.bias {
                Some(online.out_bias.as_mut_slice())
            } else {
                None
            },
            if settings.scope.bias {
                online.adam_m.as_deref_mut()
            } else {
                None
            },
            if settings.scope.bias {
                online.adam_v.as_deref_mut()
            } else {
                None
            },
            &mut self.state,
        )?;
        // Refresh the current PDF from post-replay logits (+ online bias).
        let bias = self.online.as_ref().map(|o| o.out_bias.as_slice());
        Self::logits_to_pdf(self.scratch.logits(), bias, &mut self.pdf_buffer);
        Ok(())
    }
1278
    /// Queue one (input, target, pdf) step for deferred full-parameter TBPTT.
    ///
    /// Flushes the pending segment first if `settings` differ from the ones
    /// already queued, and flushes again once the queue reaches the
    /// configured `bptt` length.
    fn enqueue_full_tbptt_step(
        &mut self,
        settings: FullTrainSettings,
        target_symbol: u8,
        pdf: &[f64],
    ) -> Result<()> {
        // A settings change mid-segment invalidates the queued steps'
        // hyperparameters: flush them under the old settings first.
        let should_flush = {
            let Some(online) = self.online.as_mut() else {
                return Ok(());
            };
            let Some(tbptt) = online.full_tbptt.as_mut() else {
                bail!("mamba full-parameter online training requires trace-enabled tbptt runtime");
            };
            tbptt.settings.is_some_and(|prev| {
                !prev.matches(
                    settings.optimizer,
                    settings.lr,
                    settings.scope,
                    settings.bptt,
                    settings.clip,
                )
            }) && !tbptt.steps.is_empty()
        };
        if should_flush {
            self.flush_full_tbptt_segment()?;
        }

        let flush_now = {
            let Some(online) = self.online.as_mut() else {
                return Ok(());
            };
            let Some(tbptt) = online.full_tbptt.as_mut() else {
                bail!("mamba full-parameter online training requires trace-enabled tbptt runtime");
            };
            // The forward pass left the observed token + pre-state pending;
            // no pending token means nothing was forwarded since last flush.
            let Some(input_token) = tbptt.pending_input_token.take() else {
                return Ok(());
            };
            let input_pre_state = tbptt
                .pending_input_pre_state
                .take()
                .ok_or_else(|| anyhow::anyhow!("mamba full tbptt pending pre-state is missing"))?;
            // First step of a new segment fixes the replay start state.
            if tbptt.steps.is_empty() {
                tbptt.segment_start_state = Some(input_pre_state);
            }
            tbptt.settings = Some(settings);
            tbptt.steps.push(FullTbpttStep {
                input_token,
                target_symbol,
                pdf: pdf.to_vec(),
            });
            // Segment full once `bptt` steps are queued (min length 1).
            tbptt.steps.len() >= settings.bptt.max(1)
        };
        if flush_now {
            self.flush_full_tbptt_segment()?;
        }
        Ok(())
    }
1336
1337    /// Begin a policy stream with optional total symbol count.
1338    pub fn begin_online_policy_stream(&mut self, total_symbols: Option<u64>) -> Result<()> {
1339        self.finish_online_policy_stream()?;
1340        self.prepare_policy_stream(total_symbols)
1341    }
1342
1343    /// Flush any pending TBPTT segment while preserving current predictive state.
1344    pub fn finish_online_policy_stream(&mut self) -> Result<()> {
1345        self.flush_full_tbptt_segment()
1346    }
1347
1348    /// Reset hidden state and TBPTT bookkeeping for a fresh stream.
1349    pub fn restart_online_policy_stream(&mut self, total_symbols: Option<u64>) -> Result<()> {
1350        self.finish_online_policy_stream()?;
1351        self.state.reset();
1352        self.clear_online_training_buffers();
1353        self.prepare_policy_stream(total_symbols)
1354    }
1355
1356    /// Reset and compute initial distribution.
1357    pub fn reset_and_prime(&mut self) {
1358        self.state.reset();
1359        self.clear_online_training_buffers();
1360        self.refresh_current_pdf(0);
1361    }
1362
1363    /// Capture runtime snapshot.
1364    pub fn snapshot_runtime(&self) -> RuntimeSnapshot {
1365        RuntimeSnapshot {
1366            model: self.model.clone(),
1367            scratch: self.scratch.clone(),
1368            state: self.state.clone(),
1369            pdf_buffer: self.pdf_buffer.clone(),
1370            online: self.online.clone(),
1371        }
1372    }
1373
1374    /// Restore runtime snapshot.
1375    pub fn restore_runtime(&mut self, snapshot: &RuntimeSnapshot) {
1376        self.model = snapshot.model.clone();
1377        self.scratch = snapshot.scratch.clone();
1378        self.state = snapshot.state.clone();
1379        self.pdf_buffer.clone_from(&snapshot.pdf_buffer);
1380        self.online = snapshot.online.clone();
1381    }
1382
1383    /// Condition the model on a chain of prefixes.
1384    pub fn absorb_chain(&mut self, parts: &[&[u8]]) -> Result<()> {
1385        let total = parts
1386            .iter()
1387            .fold(0u64, |acc, part| acc.saturating_add(part.len() as u64));
1388        self.fit_chain(parts, Some(total))
1389    }
1390
1391    /// Cross entropy from current runtime state.
1392    pub fn cross_entropy_from_current(&mut self, data: &[u8]) -> Result<f64> {
1393        if data.is_empty() {
1394            return Ok(0.0);
1395        }
1396        self.begin_online_policy_stream(Some(data.len() as u64))?;
1397        let mut total_bits = 0.0;
1398        for &byte in data {
1399            let p = self.pdf_buffer[byte as usize].max(1e-300);
1400            total_bits -= p.log2();
1401            self.observe_symbol_from_current_pdf(byte)?;
1402        }
1403        self.finish_online_policy_stream()?;
1404        Ok(total_bits / (data.len() as f64))
1405    }
1406
    /// Fit on `fit_parts`, then reset stream state and score `data` without further adaptation.
    pub fn cross_entropy_frozen_plugin_chain(
        &mut self,
        fit_parts: &[&[u8]],
        data: &[u8],
    ) -> Result<f64> {
        if data.is_empty() {
            return Ok(0.0);
        }
        // Without trainable online state, fitting is a no-op; fall back to a
        // plain scoring pass.
        if !self.can_adapt_online() {
            return self.cross_entropy(data);
        }
        // Flush any pending segment, fit on the chain from a clean state...
        self.finish_online_policy_stream()?;
        self.reset_and_prime();
        let fit_total = fit_parts
            .iter()
            .fold(0u64, |acc, part| acc.saturating_add(part.len() as u64));
        self.fit_chain(fit_parts, Some(fit_total))?;
        // ...then reset the recurrence again and score with frozen params.
        self.reset_and_prime();

        let mut total_bits = 0.0;
        for &byte in data {
            // Probability floor keeps log2 finite on near-zero mass.
            total_bits -= self.pdf_buffer[byte as usize].max(1e-300).log2();
            // Inference-only advance: no training during scoring.
            self.advance_inference_only(byte);
        }
        Ok(total_bits / (data.len() as f64))
    }
1434
1435    /// Whether online adaptation is enabled.
1436    pub fn is_online(&self) -> bool {
1437        self.online.is_some()
1438    }
1439
1440    /// Returns `true` when the current online configuration can actually adapt parameters.
1441    pub fn can_adapt_online(&self) -> bool {
1442        let Some(online) = &self.online else {
1443            return false;
1444        };
1445        match &online.policy {
1446            Some(policy) => llm_policy::policy_can_train(policy),
1447            None => !matches!(online.cfg.train_mode, OnlineTrainMode::None),
1448        }
1449    }
1450
1451    /// Tokens processed by online updater.
1452    pub fn tokens_processed(&self) -> u64 {
1453        self.online.as_ref().map_or(0, |s| s.tokens_processed)
1454    }
1455
1456    /// Canonical online method string.
1457    pub fn online_method_string(&self) -> Option<&str> {
1458        self.online.as_ref().map(|s| s.canonical_method.as_str())
1459    }
1460
1461    /// Vocabulary size.
1462    pub fn vocab_size(&self) -> usize {
1463        self.model.config().vocab_size
1464    }
1465
1466    /// Convert logits to PDF with online bias if active.
1467    pub fn online_apply_logits_bias(&self, logits: &[f32], pdf_out: &mut [f64]) {
1468        let bias = self.online.as_ref().map(|s| s.out_bias.as_slice());
1469        Self::logits_to_pdf(logits, bias, pdf_out);
1470    }
1471
    /// Convert logits + optional bias into stable normalized PDF.
    ///
    /// Delegates to the shared floored-softmax kernel; `pdf_out` receives one
    /// probability per logit.
    pub fn logits_to_pdf(logits: &[f32], bias: Option<&[f32]>, pdf_out: &mut [f64]) {
        softmax_pdf_floor_with_bias(logits, bias, pdf_out);
    }
1476
1477    #[inline]
1478    /// Forward one token and emit the resulting (optionally biased) PDF.
1479    pub fn forward_to_pdf(&mut self, token: u32, pdf_out: &mut [f64]) {
1480        self.forward_with_online_record(token);
1481        let bias = self.online.as_ref().map(|o| o.out_bias.as_slice());
1482        Self::logits_to_pdf(self.scratch.logits(), bias, pdf_out);
1483    }
1484
1485    /// Snapshot online bias only.
1486    pub fn online_bias_snapshot(&self) -> Option<Vec<f32>> {
1487        self.online.as_ref().map(|o| o.out_bias.clone())
1488    }
1489
1490    #[inline]
1491    /// Borrow online output bias vector when online mode is active.
1492    pub fn online_bias_slice(&self) -> Option<&[f32]> {
1493        self.online.as_ref().map(|o| o.out_bias.as_slice())
1494    }
1495
    /// Apply one online update using external PDF.
    ///
    /// Thin public alias over the internal update path; `pdf` is the
    /// distribution that priced `symbol`.
    pub fn online_update_from_pdf(&mut self, symbol: u8, pdf: &[f64]) -> Result<()> {
        self.online_update_with_pdf(symbol, pdf)
    }
1500
    /// Resolve the effective training action for the next symbol.
    ///
    /// Starts from the cfg-level defaults (optimizer, lr, stride; head+bias
    /// scope when training is enabled) and then lets the policy's next action
    /// override everything. Returns `(optimizer, lr, stride, scope, bptt, clip)`.
    fn resolve_online_train_action(
        online: &mut OnlineRuntime,
    ) -> Result<(OptimizerKind, f32, u64, mamba1::TrainScopeMask, usize, f32)> {
        let mut optimizer = match online.cfg.train_mode {
            // `None` still needs a placeholder optimizer; the empty scope
            // below is what actually disables training.
            OnlineTrainMode::None => OptimizerKind::Sgd,
            OnlineTrainMode::Sgd => OptimizerKind::Sgd,
            OnlineTrainMode::Adam => OptimizerKind::Adam,
        };
        let mut lr = online.cfg.lr.max(0.0);
        let mut stride = online.cfg.stride.max(1) as u64;
        let mut scope = mamba1::TrainScopeMask::default();
        let default_train = !matches!(online.cfg.train_mode, OnlineTrainMode::None);
        // Default (policy-less) training touches only head weights and bias.
        scope.head = default_train;
        scope.bias = default_train;
        let mut bptt = 1usize;
        let mut clip = 0.0f32;

        if let Some(action) = online.next_policy_action()? {
            match action {
                PolicyAction::Infer => {
                    // Inference-only window: clear the scope entirely.
                    scope = mamba1::TrainScopeMask::default();
                }
                PolicyAction::Train(train) => {
                    optimizer = train.optimizer;
                    lr = train.hyper.lr.max(0.0);
                    stride = train.hyper.stride.max(1) as u64;
                    bptt = train.hyper.bptt.max(1);
                    clip = train.hyper.clip.max(0.0);
                    if train.scope.all {
                        scope = mamba1::TrainScopeMask::all();
                    } else {
                        // Map the policy's named scopes onto the mask fields.
                        scope = mamba1::TrainScopeMask::default();
                        scope.embed = train.scope.contains("embed");
                        scope.layer_norm = train.scope.contains("layer_norm");
                        scope.mixer_conv = train.scope.contains("mixer_conv");
                        scope.mixer_ssm = train.scope.contains("mixer_ssm");
                        scope.mixer_proj = train.scope.contains("mixer_proj");
                        scope.head = train.scope.contains("head");
                        scope.bias = train.scope.contains("bias");
                    }
                }
            }
        }
        Ok((optimizer, lr, stride, scope, bptt, clip))
    }
1546
1547    #[inline]
1548    /// Update online state from `pdf`, then advance model state with `symbol`.
1549    pub fn observe_symbol_from_pdf(&mut self, symbol: u8, pdf: &[f64]) -> Result<()> {
1550        self.online_update_with_pdf(symbol, pdf)?;
1551        self.refresh_current_pdf(symbol as u32);
1552        Ok(())
1553    }
1554
    /// Core online update: decide the training action for this symbol and
    /// either queue it for deferred TBPTT or apply an immediate bptt=1 step.
    fn online_update_with_pdf(&mut self, symbol: u8, pdf: &[f64]) -> Result<()> {
        // Phase 1: resolve the action under a scoped `online` borrow.
        let (optimizer, lr, stride_hit, scope, bptt, clip) = {
            let Some(online) = self.online.as_mut() else {
                return Ok(());
            };
            online.tokens_processed = online.tokens_processed.saturating_add(1);
            let (optimizer, lr, stride, scope, bptt, clip) =
                Self::resolve_online_train_action(online)?;
            // Stride gating counts only symbols that could train something.
            let mut stride_hit = false;
            if scope.trains_model_params() || scope.bias {
                online.policy_train_steps = online.policy_train_steps.saturating_add(1);
                stride_hit = stride <= 1 || (online.policy_train_steps % stride) == 0;
            }
            (optimizer, lr, stride_hit, scope, bptt, clip)
        };

        // Nothing to train this step: close out any open TBPTT segment and
        // drop the pending forward record so it never leaks into a later one.
        if (!scope.trains_model_params() && !scope.bias) || !stride_hit || lr == 0.0 {
            self.flush_full_tbptt_segment()?;
            if let Some(online) = self.online.as_mut()
                && let Some(tbptt) = online.full_tbptt.as_mut()
            {
                tbptt.pending_input_token = None;
                tbptt.pending_input_pre_state = None;
            }
            return Ok(());
        }

        // Lazily allocate Adam moments for the output bias.
        if matches!(optimizer, OptimizerKind::Adam)
            && let Some(online) = self.online.as_mut()
            && scope.bias
            && (online.adam_m.is_none() || online.adam_v.is_none())
        {
            online.adam_m = Some(vec![0.0; online.out_bias.len()]);
            online.adam_v = Some(vec![0.0; online.out_bias.len()]);
        }

        // Non-head scopes with bptt>1 go through the deferred TBPTT queue.
        let trains_non_head = scope.embed
            || scope.layer_norm
            || scope.mixer_conv
            || scope.mixer_ssm
            || scope.mixer_proj;
        if trains_non_head && bptt > 1 {
            let settings = FullTrainSettings {
                optimizer,
                lr,
                scope,
                bptt,
                clip,
            };
            return self.enqueue_full_tbptt_step(settings, symbol, pdf);
        }

        // Immediate path: flush any queued segment first, then drop the
        // pending forward record — this step is trained in place below.
        self.flush_full_tbptt_segment()?;
        if let Some(online) = self.online.as_mut()
            && let Some(tbptt) = online.full_tbptt.as_mut()
        {
            tbptt.pending_input_token = None;
            tbptt.pending_input_pre_state = None;
        }
        if scope.trains_model_params() {
            self.scratch.set_capture_train_trace(true);
        }
        // Lazily allocate full-parameter Adam state on first use.
        if matches!(optimizer, OptimizerKind::Adam)
            && scope.trains_model_params()
            && self.online.as_ref().is_some_and(|o| o.full_adam.is_none())
        {
            let full_adam = self.model.as_ref().new_full_adam_state();
            if let Some(online) = self.online.as_mut()
                && online.full_adam.is_none()
            {
                online.full_adam = Some(full_adam);
            }
        }

        // Copy-on-write detach before mutating shared model weights.
        let model = Arc::make_mut(&mut self.model);
        let Some(online) = self.online.as_mut() else {
            return Ok(());
        };
        // Destructure so each optimizer buffer can be borrowed independently.
        let OnlineRuntime {
            out_bias,
            adam_m,
            adam_v,
            full_adam,
            adam_t,
            ..
        } = online;
        model.online_train_step_bptt1(
            &mut self.scratch,
            &self.state,
            symbol,
            pdf,
            scope,
            optimizer,
            lr,
            clip,
            adam_t,
            full_adam.as_mut(),
            // Bias buffers are handed over only when the scope trains them.
            if scope.bias {
                Some(out_bias.as_mut_slice())
            } else {
                None
            },
            if scope.bias {
                adam_m.as_deref_mut()
            } else {
                None
            },
            if scope.bias {
                adam_v.as_deref_mut()
            } else {
                None
            },
        )
    }
1669
1670    #[inline]
1671    fn refresh_current_pdf(&mut self, token: u32) {
1672        self.forward_with_online_record(token);
1673        let bias = self.online.as_ref().map(|o| o.out_bias.as_slice());
1674        Self::logits_to_pdf(self.scratch.logits(), bias, &mut self.pdf_buffer);
1675    }
1676
1677    fn fit_chain(&mut self, parts: &[&[u8]], total_symbols: Option<u64>) -> Result<()> {
1678        self.begin_online_policy_stream(total_symbols)?;
1679        for part in parts {
1680            for &byte in *part {
1681                self.observe_symbol_from_current_pdf(byte)?;
1682            }
1683        }
1684        self.finish_online_policy_stream()?;
1685        Ok(())
1686    }
1687
1688    #[inline]
1689    fn advance_inference_only(&mut self, symbol: u8) {
1690        self.refresh_current_pdf(symbol as u32);
1691    }
1692
1693    fn online_update_from_current_pdf(&mut self, symbol: u8) -> Result<()> {
1694        let pdf_snapshot = self.pdf_buffer.clone();
1695        self.online_update_with_pdf(symbol, &pdf_snapshot)
1696    }
1697
    #[inline]
    /// Update online state using current internal PDF, then consume `symbol`.
    ///
    /// This is the per-symbol step shared by compression, decompression, and
    /// cross-entropy paths: the update uses the prediction made *before*
    /// seeing `symbol`, then the model advances on `symbol` to refresh the
    /// PDF for the next position.
    pub fn observe_symbol_from_current_pdf(&mut self, symbol: u8) -> Result<()> {
        self.online_update_from_current_pdf(symbol)?;
        self.refresh_current_pdf(symbol as u32);
        Ok(())
    }
1705
    /// Export model and online sidecar.
    ///
    /// Writes up to three artifacts derived from `model_path`:
    /// - the model weights via `save_safetensors`,
    /// - a JSON sidecar (same stem, `.json`) describing online-training state,
    /// - a full-Adam optimizer sidecar (`.opt.safetensors`) when present.
    ///
    /// A stale optimizer sidecar is deleted when no full-Adam state exists,
    /// so a later resume cannot pair the weights with mismatched moments.
    pub fn export_online<P: AsRef<Path>>(&self, model_path: P) -> Result<()> {
        let model_path = model_path.as_ref();
        self.model.save_safetensors(model_path)?;
        let opt_sidecar = optimizer_sidecar_path(model_path);

        let sidecar = model_path.with_extension("json");
        let meta = if let Some(online) = &self.online {
            // Persist (or clear) the full-Adam optimizer moments first.
            if let Some(full_adam) = online.full_adam.as_ref() {
                self.model
                    .save_full_adam_safetensors(full_adam, &opt_sidecar)?;
            } else if opt_sidecar.exists() {
                // Best-effort removal; failure to delete is not fatal.
                let _ = fs::remove_file(&opt_sidecar);
            }
            let train_mode = match online.cfg.train_mode {
                OnlineTrainMode::None => "none",
                OnlineTrainMode::Sgd => "sgd",
                OnlineTrainMode::Adam => "adam",
            };
            // Keys here are the load-time contract of `maybe_load_sidecar`.
            json!({
                "version": 1,
                "method": online.canonical_method,
                "policy": online.policy.as_ref().map(LlmPolicy::canonical),
                "policy_cursor": online.policy_runtime.as_ref().map(PolicyRuntime::cursor).unwrap_or(0),
                "policy_stream_total": online.policy_stream_total,
                "policy_train_steps": online.policy_train_steps,
                "training_mode": train_mode,
                "tokens_processed": online.tokens_processed,
                "adam_t": online.adam_t,
                "has_full_adam": online.full_adam.is_some(),
                "config": {
                    "hidden": online.cfg.hidden,
                    "layers": online.cfg.layers,
                    "intermediate": online.cfg.intermediate,
                    "state": online.cfg.state,
                    "conv": online.cfg.conv,
                    "dt_rank": online.cfg.dt_rank,
                    "seed": online.cfg.seed,
                    "lr": online.cfg.lr,
                    "stride": online.cfg.stride.max(1),
                },
                "output_bias": online.out_bias,
                "adam_m": online.adam_m,
                "adam_v": online.adam_v,
                "lm_head_adam_m": online.lm_head_adam_m,
                "lm_head_adam_v": online.lm_head_adam_v,
            })
        } else {
            // No online runtime: emit minimal inference-only metadata and
            // drop any leftover optimizer sidecar.
            if opt_sidecar.exists() {
                let _ = fs::remove_file(&opt_sidecar);
            }
            json!({
                "version": 1,
                "method": format!("file:{}", model_path.display()),
                "training_mode": "none",
                "tokens_processed": 0,
            })
        };

        fs::write(sidecar, serde_json::to_vec_pretty(&meta)?)?;
        Ok(())
    }
1768
1769    fn maybe_load_sidecar(&mut self) -> Result<()> {
1770        let Some(model_path) = &self.source_model_path else {
1771            return Ok(());
1772        };
1773        let sidecar = model_path.with_extension("json");
1774        if !sidecar.exists() {
1775            return Ok(());
1776        }
1777
1778        let raw = fs::read(&sidecar)?;
1779        let v: serde_json::Value = serde_json::from_slice(&raw)?;
1780        let parse_vec_f32 = |key: &str| -> Option<Vec<f32>> {
1781            v.get(key).and_then(|arr| arr.as_array()).map(|arr| {
1782                arr.iter()
1783                    .map(|x| x.as_f64().unwrap_or(0.0) as f32)
1784                    .collect::<Vec<f32>>()
1785            })
1786        };
1787        let output_bias = v
1788            .get("output_bias")
1789            .and_then(|arr| arr.as_array())
1790            .map(|arr| {
1791                arr.iter()
1792                    .map(|x| x.as_f64().unwrap_or(0.0) as f32)
1793                    .collect::<Vec<f32>>()
1794            });
1795
1796        let method = v
1797            .get("method")
1798            .and_then(|m| m.as_str())
1799            .map(|s| s.to_string())
1800            .unwrap_or_else(|| format!("file:{}", model_path.display()));
1801        let has_full_adam = v
1802            .get("has_full_adam")
1803            .and_then(|x| x.as_bool())
1804            .unwrap_or(false);
1805        let policy = v
1806            .get("policy")
1807            .and_then(|p| p.as_str())
1808            .and_then(|s| llm_policy::parse_policy_segment(s, MAMBA_TRAIN_SCOPES).ok());
1809        let tokens = v
1810            .get("tokens_processed")
1811            .and_then(|t| t.as_u64())
1812            .unwrap_or(0);
1813
1814        if let Some(mut out_bias) = output_bias {
1815            out_bias.resize(self.vocab_size(), 0.0);
1816            let mut cfg = OnlineConfig::default();
1817            if let Some(cfg_v) = v.get("config").and_then(|x| x.as_object()) {
1818                if let Some(x) = cfg_v.get("hidden").and_then(|x| x.as_u64()) {
1819                    cfg.hidden = x as usize;
1820                }
1821                if let Some(x) = cfg_v.get("layers").and_then(|x| x.as_u64()) {
1822                    cfg.layers = x as usize;
1823                }
1824                if let Some(x) = cfg_v.get("intermediate").and_then(|x| x.as_u64()) {
1825                    cfg.intermediate = x as usize;
1826                }
1827                if let Some(x) = cfg_v.get("state").and_then(|x| x.as_u64()) {
1828                    cfg.state = x as usize;
1829                }
1830                if let Some(x) = cfg_v.get("conv").and_then(|x| x.as_u64()) {
1831                    cfg.conv = x as usize;
1832                }
1833                if let Some(x) = cfg_v.get("dt_rank").and_then(|x| x.as_u64()) {
1834                    cfg.dt_rank = x as usize;
1835                }
1836                if let Some(x) = cfg_v.get("seed").and_then(|x| x.as_u64()) {
1837                    cfg.seed = x;
1838                }
1839                if let Some(x) = cfg_v.get("lr").and_then(|x| x.as_f64()) {
1840                    cfg.lr = x as f32;
1841                }
1842                if let Some(x) = cfg_v.get("stride").and_then(|x| x.as_u64()) {
1843                    cfg.stride = (x as usize).max(1);
1844                }
1845            }
1846            cfg.train_mode = v
1847                .get("training_mode")
1848                .and_then(|x| x.as_str())
1849                .and_then(|s| parse_train_mode_token(s).ok())
1850                .unwrap_or(OnlineTrainMode::None);
1851            let needs_full_trace = policy
1852                .as_ref()
1853                .map(policy_needs_full_trace)
1854                .unwrap_or(false);
1855
1856            self.online = Some(OnlineRuntime {
1857                cfg,
1858                canonical_method: method,
1859                policy,
1860                policy_runtime: None,
1861                needs_full_trace,
1862                policy_stream_total: v.get("policy_stream_total").and_then(|x| x.as_u64()),
1863                policy_train_steps: v
1864                    .get("policy_train_steps")
1865                    .and_then(|x| x.as_u64())
1866                    .unwrap_or(0),
1867                tokens_processed: tokens,
1868                out_bias,
1869                adam_m: parse_vec_f32("adam_m"),
1870                adam_v: parse_vec_f32("adam_v"),
1871                full_adam: None,
1872                lm_head_adam_m: parse_vec_f32("lm_head_adam_m"),
1873                lm_head_adam_v: parse_vec_f32("lm_head_adam_v"),
1874                adam_t: v.get("adam_t").and_then(|x| x.as_u64()).unwrap_or(0) as usize,
1875                full_tbptt: needs_full_trace.then(|| FullTbpttRuntime {
1876                    pending_input_token: None,
1877                    pending_input_pre_state: None,
1878                    segment_start_state: None,
1879                    steps: Vec::new(),
1880                    settings: None,
1881                }),
1882            });
1883            let opt_sidecar = optimizer_sidecar_path(model_path);
1884            if opt_sidecar.exists() {
1885                if let Some(online) = self.online.as_mut() {
1886                    online.full_adam = Some(self.model.load_full_adam_safetensors(&opt_sidecar)?);
1887                }
1888            } else if has_full_adam {
1889                bail!(
1890                    "missing optimizer sidecar '{}' required for exact online resume",
1891                    opt_sidecar.display()
1892                );
1893            }
1894            if let Some(cursor) = v.get("policy_cursor").and_then(|x| x.as_u64())
1895                && let Some(online) = self.online.as_mut()
1896                && online.policy.is_some()
1897            {
1898                let train_steps = online.policy_train_steps;
1899                online.prepare_policy_stream(online.policy_stream_total)?;
1900                online.policy_train_steps = train_steps;
1901                if let Some(rt) = online.policy_runtime.as_mut() {
1902                    rt.set_cursor(cursor);
1903                }
1904            }
1905            self.scratch
1906                .set_capture_train_trace(self.online.as_ref().is_some_and(|o| o.needs_full_trace));
1907        }
1908        Ok(())
1909    }
1910
1911    /// Compress into writer.
1912    pub fn compress_into<W: Write>(
1913        &mut self,
1914        data: &[u8],
1915        coder: CoderType,
1916        w: &mut W,
1917    ) -> Result<()> {
1918        self.restart_online_policy_stream(Some(data.len() as u64))?;
1919        let checksum = crc32(data);
1920        let header = Header::new(coder, data.len() as u64, checksum);
1921        header.write(w)?;
1922
1923        match coder {
1924            CoderType::AC => self.compress_ac_iter(data.iter().copied(), w)?,
1925            CoderType::RANS => self.compress_rans_iter(data.iter().copied(), w)?,
1926        }
1927        self.finish_online_policy_stream()?;
1928        Ok(())
1929    }
1930
1931    /// Compress a chain of byte slices into writer.
1932    pub fn compress_chain_into<W: Write>(
1933        &mut self,
1934        parts: &[&[u8]],
1935        coder: CoderType,
1936        w: &mut W,
1937    ) -> Result<()> {
1938        let mut total_len: u64 = 0;
1939        let mut hasher = crc32fast::Hasher::new();
1940        for p in parts {
1941            total_len = total_len.saturating_add(p.len() as u64);
1942            hasher.update(p);
1943        }
1944        self.restart_online_policy_stream(Some(total_len))?;
1945        let checksum = hasher.finalize();
1946        let header = Header::new(coder, total_len, checksum);
1947        header.write(w)?;
1948
1949        let it = parts.iter().flat_map(|p| p.iter().copied());
1950        match coder {
1951            CoderType::AC => self.compress_ac_iter(it, w)?,
1952            CoderType::RANS => self.compress_rans_iter(it, w)?,
1953        }
1954        self.finish_online_policy_stream()?;
1955        Ok(())
1956    }
1957
1958    /// Return compressed size without output allocation.
1959    pub fn compress_size(&mut self, data: &[u8], coder: CoderType) -> Result<u64> {
1960        let mut w = CountingWriter::new();
1961        self.compress_into(data, coder, &mut w)?;
1962        Ok(w.bytes_written())
1963    }
1964
1965    /// Return compressed size for chained inputs.
1966    pub fn compress_size_chain(&mut self, parts: &[&[u8]], coder: CoderType) -> Result<u64> {
1967        let mut w = CountingWriter::new();
1968        self.compress_chain_into(parts, coder, &mut w)?;
1969        Ok(w.bytes_written())
1970    }
1971
1972    /// Compress to bytes.
1973    pub fn compress(&mut self, data: &[u8], coder: CoderType) -> Result<Vec<u8>> {
1974        let mut out = Vec::new();
1975        self.compress_into(data, coder, &mut out)?;
1976        Ok(out)
1977    }
1978
1979    fn compress_ac_iter<I, W: Write>(&mut self, data: I, output: &mut W) -> Result<()>
1980    where
1981        I: IntoIterator<Item = u8>,
1982    {
1983        let mut encoder = ArithmeticEncoder::new(output);
1984
1985        self.refresh_current_pdf(0);
1986
1987        for byte in data {
1988            quantize_pdf_to_cdf_with_buffer(
1989                &self.pdf_buffer,
1990                &mut self.cdf_buffer_ac,
1991                &mut self.ac_freq_buffer,
1992            );
1993            let sym = byte as usize;
1994            let lo = self.cdf_buffer_ac[sym] as u64;
1995            let hi = self.cdf_buffer_ac[sym + 1] as u64;
1996            encoder.encode_counts(lo, hi, CDF_TOTAL as u64)?;
1997            self.observe_symbol_from_current_pdf(byte)?;
1998        }
1999
2000        let _ = encoder.finish()?;
2001        Ok(())
2002    }
2003
2004    fn compress_rans_iter<I, W: Write>(&mut self, data: I, output: &mut W) -> Result<()>
2005    where
2006        I: IntoIterator<Item = u8>,
2007    {
2008        let mut encoder = BlockedRansEncoder::new();
2009
2010        self.refresh_current_pdf(0);
2011
2012        for byte in data {
2013            quantize_pdf_to_rans_cdf_with_buffer(
2014                &self.pdf_buffer,
2015                &mut self.cdf_buffer_rans,
2016                &mut self.rans_freq_buffer,
2017            );
2018            let sym = byte as usize;
2019            let cdf = Cdf::new(
2020                self.cdf_buffer_rans[sym],
2021                self.cdf_buffer_rans[sym + 1],
2022                ANS_TOTAL,
2023            );
2024            encoder.encode(cdf);
2025            self.observe_symbol_from_current_pdf(byte)?;
2026        }
2027
2028        let blocks = encoder.finish();
2029        output.write_all(&(blocks.len() as u32).to_le_bytes())?;
2030        for block in &blocks {
2031            output.write_all(&(block.len() as u32).to_le_bytes())?;
2032            output.write_all(block)?;
2033        }
2034        Ok(())
2035    }
2036
2037    /// Decompress bytes.
2038    pub fn decompress(&mut self, data: &[u8]) -> Result<Vec<u8>> {
2039        let mut cursor = Cursor::new(data);
2040        let header = Header::read(&mut cursor)?;
2041
2042        self.restart_online_policy_stream(Some(header.original_len))?;
2043        let compressed = &data[Header::SIZE..];
2044        let result = match header.coder_type() {
2045            CoderType::AC => self.decompress_ac(compressed, header.original_len as usize)?,
2046            CoderType::RANS => self.decompress_rans(compressed, header.original_len as usize)?,
2047        };
2048
2049        let actual_crc = crc32(&result);
2050        if actual_crc != header.crc32 {
2051            bail!(
2052                "CRC32 mismatch: expected 0x{:08X}, got 0x{:08X}",
2053                header.crc32,
2054                actual_crc
2055            );
2056        }
2057        self.finish_online_policy_stream()?;
2058        Ok(result)
2059    }
2060
2061    fn decompress_ac(&mut self, compressed: &[u8], original_len: usize) -> Result<Vec<u8>> {
2062        let mut decoder = ArithmeticDecoder::new(compressed)?;
2063        let mut result = Vec::with_capacity(original_len);
2064
2065        self.refresh_current_pdf(0);
2066
2067        for _ in 0..original_len {
2068            quantize_pdf_to_cdf_with_buffer(
2069                &self.pdf_buffer,
2070                &mut self.cdf_buffer_ac,
2071                &mut self.ac_freq_buffer,
2072            );
2073            let sym = decoder.decode_symbol_counts(&self.cdf_buffer_ac, CDF_TOTAL)?;
2074            let byte = sym as u8;
2075            result.push(byte);
2076            self.observe_symbol_from_current_pdf(byte)?;
2077        }
2078
2079        Ok(result)
2080    }
2081
2082    fn decompress_rans(&mut self, compressed: &[u8], original_len: usize) -> Result<Vec<u8>> {
2083        if compressed.len() < 4 {
2084            bail!("rANS data too short");
2085        }
2086        let block_count =
2087            u32::from_le_bytes([compressed[0], compressed[1], compressed[2], compressed[3]])
2088                as usize;
2089
2090        let mut blocks = Vec::with_capacity(block_count);
2091        let mut pos = 4usize;
2092        for _ in 0..block_count {
2093            if pos + 4 > compressed.len() {
2094                bail!("truncated rANS block header");
2095            }
2096            let len = u32::from_le_bytes([
2097                compressed[pos],
2098                compressed[pos + 1],
2099                compressed[pos + 2],
2100                compressed[pos + 3],
2101            ]) as usize;
2102            pos += 4;
2103            if pos + len > compressed.len() {
2104                bail!("truncated rANS block data");
2105            }
2106            blocks.push(&compressed[pos..pos + len]);
2107            pos += len;
2108        }
2109
2110        let mut decoder = BlockedRansDecoder::new(blocks, original_len)?;
2111        let mut result = Vec::with_capacity(original_len);
2112
2113        self.refresh_current_pdf(0);
2114
2115        for _ in 0..original_len {
2116            quantize_pdf_to_rans_cdf_with_buffer(
2117                &self.pdf_buffer,
2118                &mut self.cdf_buffer_rans,
2119                &mut self.rans_freq_buffer,
2120            );
2121            let sym = decoder.decode(&self.cdf_buffer_rans)? as u8;
2122            result.push(sym);
2123            self.observe_symbol_from_current_pdf(sym)?;
2124        }
2125
2126        Ok(result)
2127    }
2128
    /// Cross entropy over whole sample.
    ///
    /// Resets model state via `reset_and_prime` first, so the result does not
    /// depend on whatever was streamed previously.
    pub fn cross_entropy(&mut self, data: &[u8]) -> Result<f64> {
        self.reset_and_prime();
        self.cross_entropy_from_current(data)
    }
2134
2135    /// Cross entropy conditioned on prefix chain.
2136    pub fn cross_entropy_conditional_chain(
2137        &mut self,
2138        prefix_parts: &[&[u8]],
2139        data: &[u8],
2140    ) -> Result<f64> {
2141        if data.is_empty() {
2142            return Ok(0.0);
2143        }
2144        let prefix_len = prefix_parts
2145            .iter()
2146            .fold(0usize, |acc, p| acc.saturating_add(p.len()));
2147        self.finish_online_policy_stream()?;
2148        self.reset_and_prime();
2149        self.fit_chain(prefix_parts, Some((prefix_len + data.len()) as u64))?;
2150
2151        let mut total_bits = 0.0;
2152        for &byte in data {
2153            total_bits -= self.pdf_buffer[byte as usize].max(1e-300).log2();
2154            self.observe_symbol_from_current_pdf(byte)?;
2155        }
2156        self.finish_online_policy_stream()?;
2157        Ok(total_bits / (data.len() as f64))
2158    }
2159
2160    /// Cross entropy conditioned on one prefix.
2161    pub fn cross_entropy_conditional(&mut self, prefix: &[u8], data: &[u8]) -> Result<f64> {
2162        self.cross_entropy_conditional_chain(&[prefix], data)
2163    }
2164
2165    /// Symmetric aligned joint cross entropy using the better ordering.
2166    pub fn joint_cross_entropy_aligned_min(&mut self, x: &[u8], y: &[u8]) -> Result<f64> {
2167        let n = x.len().min(y.len());
2168        if n == 0 {
2169            return Ok(0.0);
2170        }
2171        let h_xy = self.joint_cross_entropy_aligned_order(x, y, false)?;
2172        let h_yx = self.joint_cross_entropy_aligned_order(x, y, true)?;
2173        Ok(h_xy.min(h_yx))
2174    }
2175
2176    fn joint_cross_entropy_aligned_order(&mut self, x: &[u8], y: &[u8], swap: bool) -> Result<f64> {
2177        let n = x.len().min(y.len());
2178        if n == 0 {
2179            return Ok(0.0);
2180        }
2181
2182        self.restart_online_policy_stream(Some((2 * n) as u64))?;
2183
2184        self.refresh_current_pdf(0);
2185
2186        let mut total_bits = 0.0;
2187        for idx in 0..n {
2188            let a = if swap { y[idx] } else { x[idx] };
2189            let b = if swap { x[idx] } else { y[idx] };
2190
2191            total_bits -= self.pdf_buffer[a as usize].max(1e-300).log2();
2192            self.observe_symbol_from_current_pdf(a)?;
2193
2194            total_bits -= self.pdf_buffer[b as usize].max(1e-300).log2();
2195            self.observe_symbol_from_current_pdf(b)?;
2196        }
2197
2198        self.finish_online_policy_stream()?;
2199        Ok(total_bits / (n as f64))
2200    }
2201}