1#![allow(clippy::items_after_test_module)]
2
3use anyhow::{Context, Result, bail};
6use serde_json::json;
7use std::fs;
8use std::io::{Cursor, Read, Write};
9use std::path::{Path, PathBuf};
10use std::sync::Arc;
11
12use crate::backends::llm_policy::{
13 self, LlmPolicy, OptimizerKind, PolicyAction, PolicyRuntime, split_method_policy_segments,
14};
15pub mod mamba1;
17
18pub use crate::coders;
20pub use crate::coders::CoderType;
22
23use crate::coders::{
24 ANS_TOTAL, ArithmeticDecoder, ArithmeticEncoder, BlockedRansDecoder, BlockedRansEncoder,
25 CDF_TOTAL, Cdf, quantize_pdf_to_cdf_with_buffer, quantize_pdf_to_rans_cdf_with_buffer,
26};
27
28pub use mamba1::Config;
30pub use mamba1::Model;
32pub use mamba1::ScratchBuffers;
34pub use mamba1::State;
36
/// Container magic number; the hex bytes spell "ZBMM" in ASCII.
pub const MAGIC: u32 = 0x5a424d4d;
/// Current container format version (readers reject anything newer).
pub const VERSION: u8 = 1;
/// Byte-level model: one output symbol per possible byte value.
pub const VOCAB_SIZE: usize = 256;
/// Chunk length used when replaying truncated-BPTT segments.
const TBPTT_REPLAY_CHUNK: usize = 32;
/// Parameter scopes accepted by mamba training-policy segments.
const MAMBA_TRAIN_SCOPES: &[&str] = &[
    "embed",
    "layer_norm",
    "mixer_conv",
    "mixer_ssm",
    "mixer_proj",
    "head",
    "bias",
    "all",
    "none",
];
55
/// Path of the optimizer-state sidecar file stored next to a model file
/// (same stem, `.opt.safetensors` extension).
#[inline]
fn optimizer_sidecar_path(model_path: &Path) -> PathBuf {
    let mut sidecar = model_path.to_path_buf();
    sidecar.set_extension("opt.safetensors");
    sidecar
}
60
/// How an online model updates its parameters while processing a stream.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum OnlineTrainMode {
    /// No parameter updates; the model stays frozen.
    None,
    /// Plain stochastic-gradient-descent updates.
    Sgd,
    /// Adam-optimizer updates (allocates moment buffers).
    Adam,
}
71
/// Configuration for an online (randomly initialized) mamba model.
///
/// Mapped onto a concrete [`Config`] by [`OnlineConfig::to_mamba_config`].
#[derive(Clone, Debug)]
pub struct OnlineConfig {
    /// Hidden size (`Config::hidden_size`).
    pub hidden: usize,
    /// Number of layers (`Config::num_layers`).
    pub layers: usize,
    /// Inner projection size (`Config::inner_size`).
    pub intermediate: usize,
    /// SSM state size (`Config::state_size`).
    pub state: usize,
    /// Convolution kernel width (`Config::conv_kernel`).
    pub conv: usize,
    /// Rank of the dt projection (`Config::dt_rank`).
    pub dt_rank: usize,
    /// RNG seed for random weight initialization.
    pub seed: u64,
    /// Online training mode (none / sgd / adam).
    pub train_mode: OnlineTrainMode,
    /// Learning rate for online updates.
    pub lr: f32,
    /// Update stride; normalized to at least 1 when serialized.
    pub stride: usize,
}
96
impl Default for OnlineConfig {
    /// Small byte-level model defaults with training disabled.
    fn default() -> Self {
        Self {
            hidden: 256,
            layers: 6,
            intermediate: 512,
            state: 16,
            conv: 4,
            dt_rank: 16,
            seed: 0,
            train_mode: OnlineTrainMode::None,
            lr: 0.001,
            stride: 1,
        }
    }
}
113
impl OnlineConfig {
    /// Convert into a concrete mamba [`Config`], clamping every dimension to a
    /// safe minimum so degenerate user-supplied values cannot produce an
    /// invalid model shape.
    ///
    /// # Errors
    /// Returns an error if the resulting config fails `Config::validate`.
    pub fn to_mamba_config(&self) -> Result<Config> {
        let cfg = Config {
            vocab_size: VOCAB_SIZE,
            hidden_size: self.hidden.max(16),
            num_layers: self.layers.max(1),
            inner_size: self.intermediate.max(16),
            state_size: self.state.max(1),
            conv_kernel: self.conv.max(1),
            dt_rank: self.dt_rank.max(1),
            layer_norm_eps: 1e-5,
        };
        cfg.validate()?;
        Ok(cfg)
    }
}
131
/// Parsed form of a mamba method string (see [`parse_method_spec`]).
#[derive(Clone, Debug)]
pub enum MethodSpec {
    /// Load a pretrained model from disk.
    File {
        /// Path to the model file.
        path: PathBuf,
        /// Optional training/inference policy attached via `;policy:`.
        policy: Option<LlmPolicy>,
    },
    /// Build a randomly initialized model from an inline configuration.
    Online {
        /// Online model configuration.
        cfg: OnlineConfig,
        /// Optional training/inference policy attached via `;policy:`.
        policy: Option<LlmPolicy>,
    },
}
150
/// Per-compressor state for online training and policy scheduling.
#[derive(Clone)]
struct OnlineRuntime {
    /// Configuration the runtime was built from.
    cfg: OnlineConfig,
    /// Canonical method string reproducing this runtime.
    canonical_method: String,
    /// Optional training/inference policy.
    policy: Option<LlmPolicy>,
    /// Compiled policy schedule for the current stream; built lazily.
    policy_runtime: Option<PolicyRuntime>,
    /// True when the policy trains scopes beyond head/bias, requiring a
    /// full backward trace (see `policy_needs_full_trace`).
    needs_full_trace: bool,
    /// Total symbol count for the current stream, if known up front.
    policy_stream_total: Option<u64>,
    /// Number of policy-driven training steps taken in this stream.
    policy_train_steps: u64,
    /// Running count of tokens fed through the model.
    tokens_processed: u64,
    /// Learned additive bias over output logits (one entry per symbol).
    out_bias: Vec<f32>,
    /// Adam first moment for `out_bias`; allocated only when Adam is used.
    adam_m: Option<Vec<f32>>,
    /// Adam second moment for `out_bias`; allocated only when Adam is used.
    adam_v: Option<Vec<f32>>,
    /// Optimizer state for full-model (non-head) Adam training.
    full_adam: Option<mamba1::FullAdamState>,
    /// Adam first moment for the LM head (vocab * hidden entries).
    lm_head_adam_m: Option<Vec<f32>>,
    /// Adam second moment for the LM head (vocab * hidden entries).
    lm_head_adam_v: Option<Vec<f32>>,
    /// Adam timestep counter.
    adam_t: usize,
    /// Truncated-BPTT bookkeeping; present only when `needs_full_trace`.
    full_tbptt: Option<FullTbpttRuntime>,
}
170
/// Hyperparameters of a full-model training segment; used to detect when the
/// active policy settings change between steps.
#[derive(Clone, Copy, Debug)]
struct FullTrainSettings {
    /// Optimizer kind (SGD/Adam).
    optimizer: OptimizerKind,
    /// Learning rate.
    lr: f32,
    /// Which parameter groups are trainable.
    scope: mamba1::TrainScopeMask,
    /// Truncated-BPTT window length.
    bptt: usize,
    /// Gradient-clipping threshold.
    clip: f32,
}
179
impl FullTrainSettings {
    /// True when every hyperparameter matches exactly.
    ///
    /// Floats are compared via `to_bits` so the comparison is total and
    /// bitwise-exact (NaN compares equal to itself, and -0.0 differs from 0.0).
    fn matches(
        self,
        optimizer: OptimizerKind,
        lr: f32,
        scope: mamba1::TrainScopeMask,
        bptt: usize,
        clip: f32,
    ) -> bool {
        self.optimizer == optimizer
            && self.lr.to_bits() == lr.to_bits()
            && self.scope.embed == scope.embed
            && self.scope.layer_norm == scope.layer_norm
            && self.scope.mixer_conv == scope.mixer_conv
            && self.scope.mixer_ssm == scope.mixer_ssm
            && self.scope.mixer_proj == scope.mixer_proj
            && self.scope.head == scope.head
            && self.scope.bias == scope.bias
            && self.bptt == bptt
            && self.clip.to_bits() == clip.to_bits()
    }
}
202
/// One recorded step of a truncated-BPTT segment.
#[derive(Clone)]
struct FullTbpttStep {
    /// Token fed into the model at this step.
    input_token: u32,
    /// Byte the model should have predicted (the next symbol).
    target_symbol: u8,
    /// Predicted distribution captured at this step.
    pdf: Vec<f64>,
}
209
/// Accumulated trace for full-model truncated-BPTT training.
#[derive(Clone)]
struct FullTbpttRuntime {
    /// Token whose forward pass has run but whose target is not yet known.
    pending_input_token: Option<u32>,
    /// Model state snapshot taken just before the pending token's forward pass.
    pending_input_pre_state: Option<State>,
    /// Model state at the start of the current segment (replay origin).
    segment_start_state: Option<State>,
    /// Steps recorded since the segment start.
    steps: Vec<FullTbpttStep>,
    /// Settings the segment was recorded under; a change flushes the segment.
    settings: Option<FullTrainSettings>,
}
218
/// Restorable snapshot of a compressor's full runtime (model, state, buffers,
/// and online-training bookkeeping); see `snapshot_runtime`/`restore_runtime`.
#[derive(Clone)]
pub struct RuntimeSnapshot {
    /// Model weights at snapshot time (shared via `Arc`).
    model: Arc<Model>,
    /// Forward-pass scratch buffers.
    scratch: ScratchBuffers,
    /// Recurrent model state.
    state: State,
    /// Next-symbol probability buffer.
    pdf_buffer: Vec<f64>,
    /// Online-training runtime, if any.
    online: Option<OnlineRuntime>,
}
228
impl OnlineRuntime {
    /// Build the online-training runtime for a model with `vocab_size` output
    /// symbols and `hidden_size` hidden units.
    ///
    /// Adam moment buffers are allocated up front when either the cfg train
    /// mode or any policy rule uses Adam; the full-TBPTT trace is allocated
    /// only when the policy trains scopes beyond head/bias.
    fn new(
        cfg: OnlineConfig,
        canonical_method: String,
        policy: Option<LlmPolicy>,
        vocab_size: usize,
        hidden_size: usize,
    ) -> Self {
        // Adam state is needed if either the base cfg or any policy rule asks for it.
        let mut use_adam = matches!(cfg.train_mode, OnlineTrainMode::Adam);
        if let Some(pol) = &policy {
            use_adam = policy_uses_adam(pol) || use_adam;
        }
        let needs_full_trace = policy
            .as_ref()
            .map(policy_needs_full_trace)
            .unwrap_or(false);
        Self {
            canonical_method,
            cfg,
            policy,
            policy_runtime: None,
            needs_full_trace,
            policy_stream_total: None,
            policy_train_steps: 0,
            tokens_processed: 0,
            out_bias: vec![0.0; vocab_size],
            adam_m: use_adam.then(|| vec![0.0; vocab_size]),
            adam_v: use_adam.then(|| vec![0.0; vocab_size]),
            full_adam: None,
            lm_head_adam_m: use_adam.then(|| vec![0.0; vocab_size * hidden_size]),
            lm_head_adam_v: use_adam.then(|| vec![0.0; vocab_size * hidden_size]),
            adam_t: 0,
            full_tbptt: needs_full_trace.then(|| FullTbpttRuntime {
                pending_input_token: None,
                pending_input_pre_state: None,
                segment_start_state: None,
                steps: Vec::new(),
                settings: None,
            }),
        }
    }

    /// Reset per-stream policy bookkeeping and (re)compile the policy schedule
    /// for a stream of `total_symbols` symbols (`None` when unknown).
    fn prepare_policy_stream(&mut self, total_symbols: Option<u64>) -> Result<()> {
        self.policy_stream_total = total_symbols;
        self.policy_train_steps = 0;
        if let Some(tbptt) = self.full_tbptt.as_mut() {
            // NOTE(review): pending_input_token / pending_input_pre_state are
            // not cleared here, while Compressor::clear_online_training_buffers
            // clears all five tbptt fields — confirm whether pending input is
            // meant to carry across stream preparation.
            tbptt.segment_start_state = None;
            tbptt.steps.clear();
            tbptt.settings = None;
        }
        self.policy_runtime = match &self.policy {
            Some(p) => Some(PolicyRuntime::new(p.compile(total_symbols)?)),
            None => None,
        };
        Ok(())
    }

    /// Advance the policy schedule by one symbol, lazily compiling it (with an
    /// unknown stream length) on first use. Returns `None` when no policy is
    /// configured.
    #[inline]
    fn next_policy_action(&mut self) -> Result<Option<PolicyAction>> {
        if self.policy.is_none() {
            return Ok(None);
        }
        if self.policy_runtime.is_none() {
            self.prepare_policy_stream(None)?;
        }
        Ok(self.policy_runtime.as_mut().map(PolicyRuntime::next_action))
    }
}
298
299fn policy_uses_adam(policy: &LlmPolicy) -> bool {
300 use llm_policy::ScheduleRule;
301 for rule in &policy.schedule {
302 match rule {
303 ScheduleRule::Interval(interval) => {
304 if let PolicyAction::Train(train) = &interval.action
305 && matches!(train.optimizer, OptimizerKind::Adam)
306 {
307 return true;
308 }
309 }
310 ScheduleRule::Repeat(repeat) => {
311 for seg in &repeat.pattern {
312 if let PolicyAction::Train(train) = &seg.action
313 && matches!(train.optimizer, OptimizerKind::Adam)
314 {
315 return true;
316 }
317 }
318 }
319 }
320 }
321 false
322}
323
324fn scope_needs_full_trace(scope: &llm_policy::TrainScopeSet) -> bool {
325 scope.all
326 || scope.contains("embed")
327 || scope.contains("layer_norm")
328 || scope.contains("mixer_conv")
329 || scope.contains("mixer_ssm")
330 || scope.contains("mixer_proj")
331}
332
333fn policy_needs_full_trace(policy: &LlmPolicy) -> bool {
334 use llm_policy::ScheduleRule;
335 for rule in &policy.schedule {
336 match rule {
337 ScheduleRule::Interval(interval) => {
338 if let PolicyAction::Train(train) = &interval.action
339 && scope_needs_full_trace(&train.scope)
340 {
341 return true;
342 }
343 }
344 ScheduleRule::Repeat(repeat) => {
345 for seg in &repeat.pattern {
346 if let PolicyAction::Train(train) = &seg.action
347 && scope_needs_full_trace(&train.scope)
348 {
349 return true;
350 }
351 }
352 }
353 }
354 }
355 false
356}
357
358fn cfg_to_method_string(cfg: &OnlineConfig) -> String {
359 let train = match cfg.train_mode {
360 OnlineTrainMode::None => "none",
361 OnlineTrainMode::Sgd => "sgd",
362 OnlineTrainMode::Adam => "adam",
363 };
364 format!(
365 "cfg:hidden={},layers={},intermediate={},state={},conv={},dt_rank={},seed={},train={},lr={},stride={}",
366 cfg.hidden,
367 cfg.layers,
368 cfg.intermediate,
369 cfg.state,
370 cfg.conv,
371 cfg.dt_rank,
372 cfg.seed,
373 train,
374 cfg.lr,
375 cfg.stride.max(1),
376 )
377}
378
/// Softmax over `logits` (optionally shifted by `bias`), written into
/// `pdf_out`, with a tiny probability floor and a final renormalization so
/// every symbol remains codable. Falls back to a uniform scale whenever the
/// exponential sum is non-finite or zero.
fn softmax_pdf_floor_with_bias(logits: &[f32], bias: Option<&[f32]>, pdf_out: &mut [f64]) {
    debug_assert_eq!(logits.len(), pdf_out.len());
    if let Some(b) = bias {
        debug_assert_eq!(b.len(), logits.len());
    }
    if logits.is_empty() {
        return;
    }

    // Max-logit subtraction keeps exp() in a numerically safe range.
    let max_logit = match bias {
        Some(b) => logits
            .iter()
            .zip(b)
            .map(|(&l, &bb)| l + bb)
            .fold(f32::NEG_INFINITY, |m, z| if z > m { z } else { m }),
        None => logits
            .iter()
            .fold(f32::NEG_INFINITY, |m, &z| if z > m { z } else { m }),
    };

    let mut sum = 0.0f64;
    match bias {
        Some(b) => {
            for ((&l, &bb), out) in logits.iter().zip(b).zip(pdf_out.iter_mut()) {
                let e = ((l + bb - max_logit) as f64).exp();
                *out = e;
                sum += e;
            }
        }
        None => {
            for (&l, out) in logits.iter().zip(pdf_out.iter_mut()) {
                let e = ((l - max_logit) as f64).exp();
                *out = e;
                sum += e;
            }
        }
    }

    let uniform = 1.0 / (logits.len() as f64);
    let inv_sum = if sum.is_finite() && sum > 0.0 {
        1.0 / sum
    } else {
        uniform
    };

    // Clamp each probability to a tiny floor, then renormalize once more so
    // the distribution still sums to one.
    let floor = 1e-12f64;
    let mut norm = 0.0f64;
    for p in pdf_out.iter_mut() {
        *p = (*p * inv_sum).max(floor);
        norm += *p;
    }
    let inv_norm = if norm.is_finite() && norm > 0.0 {
        1.0 / norm
    } else {
        uniform
    };
    for p in pdf_out.iter_mut() {
        *p *= inv_norm;
    }
}
440
441fn parse_u64(v: &str, key: &str) -> Result<u64> {
442 v.parse::<u64>()
443 .with_context(|| format!("invalid integer value for '{key}': {v}"))
444}
445
446fn parse_usize(v: &str, key: &str) -> Result<usize> {
447 v.parse::<usize>()
448 .with_context(|| format!("invalid integer value for '{key}': {v}"))
449}
450
451fn parse_f32(v: &str, key: &str) -> Result<f32> {
452 v.parse::<f32>()
453 .with_context(|| format!("invalid float value for '{key}': {v}"))
454}
455
456fn parse_train_mode_token(v: &str) -> Result<OnlineTrainMode> {
457 let code = v.trim().to_ascii_lowercase();
458 match code.as_str() {
459 "0" | "none" | "off" => Ok(OnlineTrainMode::None),
460 "1" | "sgd" => Ok(OnlineTrainMode::Sgd),
461 "2" | "adam" => Ok(OnlineTrainMode::Adam),
462 other => bail!("unknown train mode '{other}'"),
463 }
464}
465
466fn parse_cfg_positional(csv: &str) -> Result<OnlineConfig> {
467 let vals: Vec<&str> = csv
468 .split(',')
469 .map(|s| s.trim())
470 .filter(|s| !s.is_empty())
471 .collect();
472 if vals.len() != 6 && vals.len() != 7 {
473 bail!(
474 "positional cfg format expects 6 or 7 values: hidden,intermediate,layers,train,seed,lr[,stride]"
475 );
476 }
477
478 Ok(OnlineConfig {
479 hidden: parse_usize(vals[0], "hidden")?,
480 intermediate: parse_usize(vals[1], "intermediate")?,
481 layers: parse_usize(vals[2], "layers")?,
482 train_mode: parse_train_mode_token(vals[3])?,
483 seed: parse_u64(vals[4], "seed")?,
484 lr: parse_f32(vals[5], "lr")?,
485 stride: if vals.len() == 7 {
486 parse_usize(vals[6], "stride")?
487 } else {
488 1
489 },
490 ..OnlineConfig::default()
491 })
492}
493
/// Parse a mamba method string into a [`MethodSpec`].
///
/// Accepted base forms (an optional `;policy:<...>` suffix is split off first
/// and parsed against [`MAMBA_TRAIN_SCOPES`]):
/// - `file:<path>` — load a pretrained model from disk,
/// - `cfg:<k=v,...>` — named online-configuration keys,
/// - `cfg:<csv>` — positional online configuration (no `=` present),
/// - a bare path naming an existing model file,
/// - a bare positional CSV (contains a comma).
///
/// A policy `load_from` is rejected whenever the base form already names a
/// model file, since the two would conflict.
///
/// # Errors
/// Fails on an unparsable policy segment, an empty file path, unknown cfg
/// keys, malformed values, or a base string matching none of the forms above.
pub fn parse_method_spec(method: &str) -> Result<MethodSpec> {
    let (base, policy_segment) = split_method_policy_segments(method)?;
    let parse_policy = |s: &str| llm_policy::parse_policy_segment(s, MAMBA_TRAIN_SCOPES);
    let policy = policy_segment
        .as_deref()
        .map(parse_policy)
        .transpose()
        .context("failed to parse mamba policy segment")?;

    if let Some(path) = base.strip_prefix("file:") {
        let p = PathBuf::from(path.trim());
        if p.as_os_str().is_empty() {
            bail!("empty file path in mamba method");
        }
        if policy.as_ref().and_then(|p| p.load_from.as_ref()).is_some() {
            bail!("mamba method cannot use policy load_from together with file:<path>");
        }
        return Ok(MethodSpec::File { path: p, policy });
    }

    if let Some(cfg_s) = base.strip_prefix("cfg:") {
        // No '=' anywhere means the positional CSV form.
        if !cfg_s.contains('=') {
            return Ok(MethodSpec::Online {
                cfg: parse_cfg_positional(cfg_s)?,
                policy,
            });
        }
        let mut cfg = OnlineConfig::default();
        for pair in cfg_s.split(',') {
            let pair = pair.trim();
            if pair.is_empty() {
                continue;
            }
            let (k, v) = pair
                .split_once('=')
                .with_context(|| format!("invalid cfg key/value pair '{pair}'"))?;
            let key = k.trim().to_ascii_lowercase();
            let val = v.trim();
            match key.as_str() {
                "hidden" => cfg.hidden = parse_usize(val, "hidden")?,
                "layers" => cfg.layers = parse_usize(val, "layers")?,
                "intermediate" => cfg.intermediate = parse_usize(val, "intermediate")?,
                "state" | "d_state" => cfg.state = parse_usize(val, "state")?,
                "conv" | "d_conv" => cfg.conv = parse_usize(val, "conv")?,
                "dt_rank" => cfg.dt_rank = parse_usize(val, "dt_rank")?,
                "seed" => cfg.seed = parse_u64(val, "seed")?,
                "lr" => cfg.lr = parse_f32(val, "lr")?,
                "stride" => cfg.stride = parse_usize(val, "stride")?,
                "train" | "train_mode" => cfg.train_mode = parse_train_mode_token(val)?,
                other => bail!("unknown mamba cfg key '{other}'"),
            }
        }
        return Ok(MethodSpec::Online { cfg, policy });
    }

    // Bare string: treat as a model path if it exists on disk.
    let plain = PathBuf::from(base.trim());
    if plain.exists() {
        if policy.as_ref().and_then(|p| p.load_from.as_ref()).is_some() {
            bail!("mamba method cannot use policy load_from together with file path");
        }
        return Ok(MethodSpec::File {
            path: plain,
            policy,
        });
    }

    // Last resort: a comma implies the bare positional CSV form.
    if base.contains(',') {
        return Ok(MethodSpec::Online {
            cfg: parse_cfg_positional(&base)?,
            policy,
        });
    }

    bail!(
        "mamba method must be 'file:<path>', 'cfg:<k=v,...>', positional cfg CSV, or an existing model path"
    );
}
579
/// Container header written at the start of every compressed stream.
#[derive(Debug, Clone)]
pub struct Header {
    /// Format magic; always [`MAGIC`].
    pub magic: u32,
    /// Format version; readers reject values above [`VERSION`].
    pub version: u8,
    /// Entropy-coder tag: 0 = arithmetic coding, otherwise rANS.
    pub coder: u8,
    /// Length of the original (uncompressed) data in bytes.
    pub original_len: u64,
    /// CRC-32 checksum of the original data.
    pub crc32: u32,
}
594
impl Header {
    /// Serialized size in bytes: magic (4) + version (1) + coder (1)
    /// + original_len (8) + crc32 (4).
    pub const SIZE: usize = 4 + 1 + 1 + 8 + 4;

    /// Build a header for the current format version.
    pub fn new(coder: CoderType, original_len: u64, crc32: u32) -> Self {
        Self {
            magic: MAGIC,
            version: VERSION,
            // On-disk coder tag: 0 = arithmetic coding, 1 = rANS.
            coder: match coder {
                CoderType::AC => 0,
                CoderType::RANS => 1,
            },
            original_len,
            crc32,
        }
    }

    /// Serialize the header; all multi-byte fields are little-endian.
    ///
    /// # Errors
    /// Propagates any I/O error from the writer.
    pub fn write<W: Write>(&self, w: &mut W) -> Result<()> {
        w.write_all(&self.magic.to_le_bytes())?;
        w.write_all(&[self.version])?;
        w.write_all(&[self.coder])?;
        w.write_all(&self.original_len.to_le_bytes())?;
        w.write_all(&self.crc32.to_le_bytes())?;
        Ok(())
    }

    /// Deserialize and validate a header.
    ///
    /// # Errors
    /// Fails on I/O errors, a wrong magic number, or a version newer than
    /// [`VERSION`].
    pub fn read<R: Read>(r: &mut R) -> Result<Self> {
        let mut buf4 = [0u8; 4];
        let mut buf8 = [0u8; 8];
        let mut buf1 = [0u8; 1];

        r.read_exact(&mut buf4)?;
        let magic = u32::from_le_bytes(buf4);
        if magic != MAGIC {
            bail!(
                "invalid magic number: expected 0x{:08X}, got 0x{:08X}",
                MAGIC,
                magic
            );
        }

        r.read_exact(&mut buf1)?;
        let version = buf1[0];
        if version > VERSION {
            bail!(
                "unsupported version: {} (max supported: {})",
                version,
                VERSION
            );
        }

        // The coder byte is not validated here; decoding is lenient and maps
        // any unknown value to rANS in `coder_type`.
        r.read_exact(&mut buf1)?;
        let coder = buf1[0];

        r.read_exact(&mut buf8)?;
        let original_len = u64::from_le_bytes(buf8);

        r.read_exact(&mut buf4)?;
        let crc32 = u32::from_le_bytes(buf4);

        Ok(Self {
            magic,
            version,
            coder,
            original_len,
            crc32,
        })
    }

    /// Decode the on-disk coder tag; any nonzero value falls back to rANS.
    pub fn coder_type(&self) -> CoderType {
        match self.coder {
            0 => CoderType::AC,
            _ => CoderType::RANS,
        }
    }
}
675
/// CRC-32 checksum of `data`; delegates to the shared implementation in
/// `crate::coders`.
pub fn crc32(data: &[u8]) -> u32 {
    crate::coders::crc32(data)
}
680
/// Write sink that discards bytes while counting how many were written;
/// useful for measuring an encoded size without buffering the output.
struct CountingWriter {
    n: u64,
}

impl CountingWriter {
    #[inline]
    fn new() -> Self {
        CountingWriter { n: 0 }
    }

    /// Total number of bytes accepted so far.
    #[inline]
    fn bytes_written(&self) -> u64 {
        self.n
    }
}

impl Write for CountingWriter {
    #[inline]
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let accepted = buf.len();
        // Saturating add: counting itself can never fail or overflow-panic.
        self.n = self.n.saturating_add(accepted as u64);
        Ok(accepted)
    }

    #[inline]
    fn flush(&mut self) -> std::io::Result<()> {
        // Nothing is buffered, so flushing is a no-op.
        Ok(())
    }
}
710
/// Streaming byte compressor that uses a mamba [`Model`] to predict the next
/// byte and feeds the resulting pdf to an entropy coder (AC or rANS).
pub struct Compressor {
    /// Model used for next-symbol prediction; shared via `Arc`.
    pub model: Arc<Model>,
    /// Recurrent model state, advanced one token at a time.
    pub state: State,
    /// Forward-pass scratch buffers (optionally capturing a training trace).
    pub scratch: ScratchBuffers,
    /// Probability distribution over the next byte (one entry per symbol).
    pub pdf_buffer: Vec<f64>,
    /// Reusable quantized-CDF buffer for the arithmetic coder.
    cdf_buffer_ac: Vec<u32>,
    /// Reusable frequency buffer for the arithmetic coder.
    ac_freq_buffer: Vec<i64>,
    /// Reusable quantized-CDF buffer for the rANS coder.
    cdf_buffer_rans: Vec<u32>,
    /// Reusable frequency buffer for the rANS coder.
    rans_freq_buffer: Vec<i64>,
    /// Online-training runtime; `None` for a frozen, file-loaded model.
    online: Option<OnlineRuntime>,
    /// Path the model was loaded from (used to locate the optimizer sidecar,
    /// see `optimizer_sidecar_path`).
    source_model_path: Option<PathBuf>,
}
728
#[cfg(test)]
mod tests {
    use super::*;

    /// Unique temp-file path (pid + epoch nanos) so concurrent test runs do
    /// not collide on disk.
    fn temp_path(prefix: &str, ext: &str) -> PathBuf {
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_nanos();
        std::env::temp_dir().join(format!("{prefix}_{}_{}.{}", std::process::id(), now, ext))
    }

    // Both the named `cfg:k=v` form and the positional CSV form must parse,
    // including a trailing `;policy:` segment.
    #[test]
    fn parse_method_spec_accepts_cfg_and_positional() {
        let named = parse_method_spec(
            "cfg:hidden=64,layers=2,intermediate=96,state=8,conv=3,dt_rank=4,train=sgd,lr=0.01,stride=2;policy:schedule=0..100:infer",
        )
        .expect("named cfg");
        match named {
            MethodSpec::Online { cfg, .. } => {
                assert_eq!(cfg.hidden, 64);
                assert_eq!(cfg.layers, 2);
                assert_eq!(cfg.intermediate, 96);
                assert_eq!(cfg.state, 8);
                assert_eq!(cfg.conv, 3);
                assert_eq!(cfg.dt_rank, 4);
                assert!(matches!(cfg.train_mode, OnlineTrainMode::Sgd));
                assert_eq!(cfg.stride, 2);
            }
            _ => panic!("expected online cfg"),
        }

        let positional =
            parse_method_spec("cfg:64,96,2,adam,123,0.001,3;policy:schedule=0..100:infer")
                .expect("positional cfg");
        match positional {
            MethodSpec::Online { cfg, .. } => {
                assert_eq!(cfg.hidden, 64);
                assert_eq!(cfg.intermediate, 96);
                assert_eq!(cfg.layers, 2);
                assert!(matches!(cfg.train_mode, OnlineTrainMode::Adam));
                assert_eq!(cfg.seed, 123);
                assert_eq!(cfg.stride, 3);
            }
            _ => panic!("expected online cfg"),
        }
    }

    // A cfg without a `;policy:` suffix must parse with `policy == None`.
    #[test]
    fn parse_method_spec_accepts_cfg_without_policy() {
        let spec = parse_method_spec("cfg:hidden=64,layers=2,intermediate=96").expect("cfg");
        match spec {
            MethodSpec::Online { cfg, policy } => {
                assert_eq!(cfg.hidden, 64);
                assert_eq!(cfg.layers, 2);
                assert_eq!(cfg.intermediate, 96);
                assert!(policy.is_none());
            }
            _ => panic!("expected online cfg"),
        }
    }

    // The canonical method string spells out every defaulted key and omits
    // the policy suffix when no policy was supplied.
    #[test]
    fn canonical_method_omits_policy_when_absent() {
        let c = Compressor::new_from_method("cfg:hidden=64,layers=1,intermediate=96")
            .expect("online model");
        assert_eq!(
            c.online_method_string(),
            Some(
                "cfg:hidden=64,layers=1,intermediate=96,state=16,conv=4,dt_rank=16,seed=0,train=none,lr=0.001,stride=1"
            )
        );
    }

    // Export + reload must reproduce bit-identical cross-entropy scores.
    #[test]
    fn export_reload_roundtrip_reproducible() {
        let cfg = Config {
            vocab_size: 256,
            hidden_size: 32,
            num_layers: 2,
            inner_size: 48,
            state_size: 8,
            conv_kernel: 3,
            dt_rank: 4,
            layer_norm_eps: 1e-5,
        };
        let model = Arc::new(Model::new_random(cfg.clone(), 42).expect("random model"));
        let mut c1 = Compressor::new_from_model(model);
        c1.reset_and_prime();
        let _ = c1.cross_entropy_from_current(b"mamba test").expect("score");

        let base = std::env::temp_dir().join(format!(
            "infotheory_mamba_rt_{}_{}.safetensors",
            std::process::id(),
            c1.tokens_processed()
        ));
        c1.export_online(&base).expect("export");

        let mut c2 = Compressor::new(&base).expect("reload");
        c2.reset_and_prime();
        let h1 = c1.cross_entropy(b"abcabc").expect("h1");
        let h2 = c2.cross_entropy(b"abcabc").expect("h2");
        assert!((h1 - h2).abs() < 1e-9);

        let _ = std::fs::remove_file(&base);
        let _ = std::fs::remove_file(base.with_extension("json"));
    }

    // A head+bias SGD policy must actually change LM-head weights.
    #[test]
    fn online_training_updates_lm_head_weights() {
        let method = "cfg:hidden=64,layers=2,intermediate=96,state=8,conv=3,dt_rank=4,seed=11,train=sgd,lr=0.01,stride=1;policy:schedule=0..100:train(scope=head+bias,opt=sgd,lr=0.01,stride=1,bptt=1,clip=0,momentum=0.9)";
        let mut c = Compressor::new_from_method(method).expect("online model");
        c.reset_and_prime();
        let before = c.model.lm_head_weights()[0..64].to_vec();
        let _ = c
            .cross_entropy_from_current(b"online mamba weight update")
            .expect("score");
        let after = &c.model.lm_head_weights()[0..64];
        let mut changed = false;
        for i in 0..before.len() {
            if before[i].to_bits() != after[i].to_bits() {
                changed = true;
                break;
            }
        }
        assert!(
            changed,
            "expected LM-head weights to change under online training"
        );
    }

    // scope=mixer_proj must train non-head params while leaving the head
    // bitwise-identical.
    #[test]
    fn online_training_scope_all_updates_non_head_params() {
        let method = "cfg:hidden=64,layers=1,intermediate=64,state=8,conv=3,dt_rank=4,seed=7,train=adam,lr=0.002,stride=1;policy:schedule=0..100:train(scope=mixer_proj,opt=adam,lr=0.002,stride=1,bptt=1,clip=0,momentum=0.9)";
        let mut c = Compressor::new_from_method(method).expect("online model");
        c.reset_and_prime();
        let before_head = c.model.lm_head_weights()[0..64].to_vec();
        let before_model = (*c.model).clone();
        let _ = c
            .cross_entropy_from_current(b"scope mixer_proj should train non-head mamba params")
            .expect("score");
        let after_head = &c.model.lm_head_weights()[0..64];
        let mut head_unchanged = true;
        for i in 0..before_head.len() {
            if before_head[i].to_bits() != after_head[i].to_bits() {
                head_unchanged = false;
                break;
            }
        }
        assert!(
            head_unchanged,
            "expected LM-head weights to remain unchanged under scope=mixer_proj"
        );

        // Compare forward outputs of the saved pre-training model against the
        // trained one; any logit difference proves non-head params moved.
        let mut s1 = before_model.new_state();
        let mut sc1 = ScratchBuffers::new(before_model.config());
        let mut s2 = c.model.new_state();
        let mut sc2 = ScratchBuffers::new(c.model.config());
        let logits_before = before_model.forward(&mut sc1, 0, &mut s1);
        let logits_after = c.model.forward(&mut sc2, 0, &mut s2);
        let mut changed = false;
        for idx in 0..logits_before.len().min(logits_after.len()) {
            if logits_before[idx].to_bits() != logits_after[idx].to_bits() {
                changed = true;
                break;
            }
        }
        assert!(
            changed,
            "expected non-head parameters to update under scope=mixer_proj"
        );
    }

    // scope=all with bptt>1 (truncated BPTT) must still update parameters.
    #[test]
    fn online_training_scope_all_bptt_gt_one_supported() {
        let method = "cfg:hidden=64,layers=1,intermediate=64,state=8,conv=3,dt_rank=4,seed=7,train=adam,lr=0.002,stride=1;policy:schedule=0..100:train(scope=all,opt=adam,lr=0.002,stride=1,bptt=2,clip=0,momentum=0.9)";
        let mut c = Compressor::new_from_method(method).expect("online model");
        let before_path = temp_path("mamba_tbptt_before", "safetensors");
        let after_path = temp_path("mamba_tbptt_after", "safetensors");
        c.model.save_safetensors(&before_path).expect("save before");
        c.reset_and_prime();
        let score = c
            .cross_entropy_from_current(b"abcdef")
            .expect("tbptt score");
        assert!(score.is_finite());
        c.model.save_safetensors(&after_path).expect("save after");
        let before = std::fs::read(&before_path).expect("read before");
        let after = std::fs::read(&after_path).expect("read after");
        assert_ne!(
            before, after,
            "expected tbptt full training to update params"
        );
        std::fs::remove_file(before_path).ok();
        std::fs::remove_file(after_path).ok();
    }

    // Even a single-symbol stream must flush the pending TBPTT update at
    // end of stream.
    #[test]
    fn online_training_full_tbptt_updates_first_symbol_after_priming() {
        let method = "cfg:hidden=64,layers=1,intermediate=64,state=8,conv=3,dt_rank=4,seed=33,train=adam,lr=0.0008,stride=1;policy:schedule=0..100:train(scope=all,opt=adam,lr=0.0008,stride=1,bptt=8,clip=0,momentum=0.9)";
        let mut c = Compressor::new_from_method(method).expect("online model");
        let before_path = temp_path("mamba_first_symbol_before", "safetensors");
        let after_path = temp_path("mamba_first_symbol_after", "safetensors");
        c.model.save_safetensors(&before_path).expect("save before");

        c.reset_and_prime();
        let score = c
            .cross_entropy_from_current(b"a")
            .expect("single-symbol score");
        assert!(score.is_finite());
        c.model.save_safetensors(&after_path).expect("save after");

        let before = std::fs::read(&before_path).expect("read before");
        let after = std::fs::read(&after_path).expect("read after");
        assert_ne!(
            before, after,
            "expected first symbol update to flush at stream end"
        );
        std::fs::remove_file(before_path).ok();
        std::fs::remove_file(after_path).ok();
    }

    // Full-Adam optimizer state must round-trip through the `.opt.safetensors`
    // sidecar so a reloaded compressor continues bit-identically.
    #[test]
    fn export_reload_roundtrip_preserves_full_adam_resume() {
        let method = "cfg:hidden=64,layers=1,intermediate=64,state=8,conv=3,dt_rank=4,seed=17,train=adam,lr=0.0008,stride=1;policy:schedule=0..100:train(scope=all,opt=adam,lr=0.0008,stride=1,bptt=1,clip=0,momentum=0.9)";
        let data = b"mamba full adam export/reload deterministic continuation";
        let mut c1 = Compressor::new_from_method(method).expect("online model");
        let _ = c1.compress(data, CoderType::AC).expect("pre-train pass");

        let model_path = std::env::temp_dir().join(format!(
            "infotheory_mamba_full_adam_{}_{}.safetensors",
            std::process::id(),
            c1.tokens_processed()
        ));
        c1.export_online(&model_path).expect("export");
        assert!(model_path.with_extension("opt.safetensors").exists());

        let out1 = c1
            .compress(data, CoderType::AC)
            .expect("post-export compress");
        let mut c2 = Compressor::new(&model_path).expect("reload");
        let out2 = c2.compress(data, CoderType::AC).expect("reload compress");
        assert_eq!(out1, out2, "full-adam resume must be bit-identical");

        let _ = std::fs::remove_file(&model_path);
        let _ = std::fs::remove_file(model_path.with_extension("json"));
        let _ = std::fs::remove_file(model_path.with_extension("opt.safetensors"));
    }

    // Cloning must keep the full-training trace mode usable on both copies.
    #[test]
    fn clone_keeps_full_training_trace_mode() {
        let method = "cfg:hidden=64,layers=1,intermediate=64,state=8,conv=3,dt_rank=4,seed=18,train=adam,lr=0.0008,stride=1;policy:schedule=0..100:train(scope=all,opt=adam,lr=0.0008,stride=1,bptt=1,clip=0,momentum=0.9)";
        let mut c = Compressor::new_from_method(method).expect("online model");
        let mut cloned = c.clone();
        cloned.reset_and_prime();
        let _ = cloned
            .cross_entropy_from_current(b"clone must preserve training-trace mode")
            .expect("full-training step should succeed after clone");
        c.reset_and_prime();
        let _ = c
            .cross_entropy_from_current(b"baseline run")
            .expect("baseline full-training step");
    }

    // Restoring the same snapshot twice must make identical queries score
    // identically (training state fully restored).
    #[test]
    fn runtime_snapshot_restores_non_head_training_state() {
        let method = "cfg:hidden=64,layers=1,intermediate=64,state=8,conv=3,dt_rank=4,seed=19,train=adam,lr=0.0008,stride=1;policy:schedule=0..100:train(scope=all,opt=adam,lr=0.0008,stride=1,bptt=1,clip=0,momentum=0.9)";
        let mut c = Compressor::new_from_method(method).expect("online model");
        c.reset_and_prime();
        c.absorb_chain(&[b"prior context".as_slice()])
            .expect("prefix");
        let snap = c.snapshot_runtime();

        let _ = c
            .cross_entropy_from_current(b"mutate model before restore")
            .expect("mutation pass");

        c.restore_runtime(&snap);
        let score_a = c
            .cross_entropy_from_current(b"query after restore")
            .expect("score a");

        c.restore_runtime(&snap);
        let score_b = c
            .cross_entropy_from_current(b"query after restore")
            .expect("score b");

        assert!((score_a - score_b).abs() < 1e-12);
    }

    // A clone taken mid-stream must carry enough trace state to keep training.
    #[test]
    fn clone_preserves_non_head_training_trace() {
        let method = "cfg:hidden=64,layers=1,intermediate=64,state=8,conv=3,dt_rank=4,seed=20,train=adam,lr=0.0008,stride=1;policy:schedule=0..100:train(scope=all,opt=adam,lr=0.0008,stride=1,bptt=1,clip=0,momentum=0.9)";
        let mut c = Compressor::new_from_method(method).expect("online model");
        c.reset_and_prime();
        c.absorb_chain(&[b"clone trace prefix".as_slice()])
            .expect("prefix");

        let mut cloned = c.clone();
        let score = cloned
            .cross_entropy_from_current(b"clone trace query")
            .expect("cloned full-training step");
        assert!(score.is_finite());
    }
}
1034
impl Clone for Compressor {
    /// Manual clone: rebuild from the shared model (cheap `Arc` clone of the
    /// weights), then copy every runtime buffer so the clone can diverge —
    /// e.g. continue training — without touching the original's state.
    fn clone(&self) -> Self {
        let mut cloned = Self::new_from_model(self.model.clone());
        cloned.state = self.state.clone();
        // clone_from reuses the freshly allocated buffers where possible.
        cloned.pdf_buffer.clone_from(&self.pdf_buffer);
        cloned.cdf_buffer_ac.clone_from(&self.cdf_buffer_ac);
        cloned.ac_freq_buffer.clone_from(&self.ac_freq_buffer);
        cloned.cdf_buffer_rans.clone_from(&self.cdf_buffer_rans);
        cloned.rans_freq_buffer.clone_from(&self.rans_freq_buffer);
        cloned.scratch = self.scratch.clone();
        cloned.online = self.online.clone();
        cloned.source_model_path = self.source_model_path.clone();
        cloned
    }
}
1050
1051impl Compressor {
1052 pub fn new<P: AsRef<Path>>(model_path: P) -> Result<Self> {
1054 let model_path = model_path.as_ref();
1055 let model = Arc::new(Model::load(model_path)?);
1056 let mut c = Self::new_from_model(model);
1057 c.source_model_path = Some(model_path.to_path_buf());
1058 c.maybe_load_sidecar()?;
1059 Ok(c)
1060 }
1061
1062 pub fn load_model<P: AsRef<Path>>(model_path: P) -> Result<Arc<Model>> {
1064 Ok(Arc::new(Model::load(model_path)?))
1065 }
1066
1067 pub fn new_from_model(model: Arc<Model>) -> Self {
1069 let state = model.new_state();
1070 let vocab_size = model.config().vocab_size;
1071 let scratch = ScratchBuffers::new(model.config());
1072 Self {
1073 model,
1074 state,
1075 scratch,
1076 pdf_buffer: vec![0.0; vocab_size],
1077 cdf_buffer_ac: vec![0u32; vocab_size + 1],
1078 ac_freq_buffer: vec![0i64; vocab_size],
1079 cdf_buffer_rans: vec![0u32; vocab_size + 1],
1080 rans_freq_buffer: vec![0i64; vocab_size],
1081 online: None,
1082 source_model_path: None,
1083 }
1084 }
1085
    /// Builds a runtime from a textual method spec.
    ///
    /// Two spec shapes are supported (see `parse_method_spec`):
    /// - `MethodSpec::File`: load weights from `path`, optionally attaching
    ///   an online-training policy on top of the loaded model.
    /// - `MethodSpec::Online`: build a model from an `OnlineConfig`
    ///   (random init, or loaded from `policy.load_from` with a strict
    ///   shape check) and attach a fresh online runtime.
    ///
    /// # Errors
    /// Fails when the spec cannot be parsed, the model cannot be loaded, or
    /// a `load_from` model's shape does not exactly match the config.
    pub fn new_from_method(method: &str) -> Result<Self> {
        match parse_method_spec(method)? {
            MethodSpec::File { path, policy } => {
                let mut c = Self::new(&path)?;
                if let Some(policy) = policy {
                    // Canonicalize so equivalent specs compare equal downstream.
                    let canonical_method =
                        format!("file:{};policy:{}", path.display(), policy.canonical());
                    let hidden = c.model.config().hidden_size;
                    // Reuse a runtime installed by `Self::new` (e.g. restored
                    // from a sidecar) or create a default-config one.
                    let mut online = c.online.take().unwrap_or_else(|| {
                        OnlineRuntime::new(
                            OnlineConfig::default(),
                            canonical_method.clone(),
                            Some(policy.clone()),
                            VOCAB_SIZE,
                            hidden,
                        )
                    });
                    online.canonical_method = canonical_method;
                    online.policy = Some(policy);
                    online.needs_full_trace = online
                        .policy
                        .as_ref()
                        .map(policy_needs_full_trace)
                        .unwrap_or(false);
                    // Full-trace policies buffer steps for truncated BPTT.
                    online.full_tbptt = online.needs_full_trace.then(|| FullTbpttRuntime {
                        pending_input_token: None,
                        pending_input_pre_state: None,
                        segment_start_state: None,
                        steps: Vec::new(),
                        settings: None,
                    });
                    c.online = Some(online);
                    c.scratch.set_capture_train_trace(
                        c.online.as_ref().is_some_and(|o| o.needs_full_trace),
                    );
                }
                Ok(c)
            }
            MethodSpec::Online { cfg, policy } => {
                let mcfg = cfg.to_mamba_config()?;
                let model = if let Some(load_from) =
                    policy.as_ref().and_then(|p| p.load_from.as_ref())
                {
                    let loaded = Arc::new(Model::load(load_from)?);
                    let loaded_cfg = loaded.config();
                    // Strict shape match: every structural dimension must agree.
                    let shape_ok = loaded_cfg.vocab_size == mcfg.vocab_size
                        && loaded_cfg.hidden_size == mcfg.hidden_size
                        && loaded_cfg.num_layers == mcfg.num_layers
                        && loaded_cfg.inner_size == mcfg.inner_size
                        && loaded_cfg.state_size == mcfg.state_size
                        && loaded_cfg.conv_kernel == mcfg.conv_kernel
                        && loaded_cfg.dt_rank == mcfg.dt_rank;
                    if !shape_ok {
                        bail!(
                            "mamba policy load_from shape mismatch with cfg (strict match required)"
                        );
                    }
                    loaded
                } else {
                    Arc::new(Model::new_random(mcfg, cfg.seed)?)
                };
                let mut c = Self::new_from_model(model);
                let mut canonical_method = cfg_to_method_string(&cfg);
                if let Some(policy) = policy.as_ref() {
                    canonical_method.push_str(";policy:");
                    canonical_method.push_str(&policy.canonical());
                }
                c.online = Some(OnlineRuntime::new(
                    cfg,
                    canonical_method,
                    policy,
                    VOCAB_SIZE,
                    c.model.config().hidden_size,
                ));
                c.scratch
                    .set_capture_train_trace(c.online.as_ref().is_some_and(|o| o.needs_full_trace));
                Ok(c)
            }
        }
    }
1167
1168 pub fn reset(&mut self) {
1170 self.state.reset();
1171 self.clear_online_training_buffers();
1172 }
1173
1174 fn prepare_policy_stream(&mut self, total_symbols: Option<u64>) -> Result<()> {
1175 if let Some(online) = self.online.as_mut() {
1176 online.prepare_policy_stream(total_symbols)?;
1177 }
1178 Ok(())
1179 }
1180
1181 fn clear_online_training_buffers(&mut self) {
1182 if let Some(online) = self.online.as_mut()
1183 && let Some(tbptt) = online.full_tbptt.as_mut()
1184 {
1185 tbptt.pending_input_token = None;
1186 tbptt.pending_input_pre_state = None;
1187 tbptt.segment_start_state = None;
1188 tbptt.steps.clear();
1189 tbptt.settings = None;
1190 }
1191 }
1192
1193 fn forward_with_online_record(&mut self, token: u32) {
1194 if let Some(online) = self.online.as_mut()
1195 && let Some(tbptt) = online.full_tbptt.as_mut()
1196 {
1197 tbptt.pending_input_token = Some(token);
1198 tbptt.pending_input_pre_state = Some(self.state.clone());
1199 }
1200 let _ = self
1201 .model
1202 .forward(&mut self.scratch, token, &mut self.state);
1203 }
1204
    /// Replays and trains on the buffered full-TBPTT segment, if any.
    ///
    /// Takes the buffered (settings, start state, steps) out of the runtime,
    /// lazily allocates the full-Adam state when first needed, and runs
    /// `online_train_segment_tbptt` over the segment in chunks of
    /// `TBPTT_REPLAY_CHUNK`. Afterwards the cached pdf is refreshed from the
    /// post-training logits. No-op when nothing is buffered.
    fn flush_full_tbptt_segment(&mut self) -> Result<()> {
        // Extract everything in a scoped borrow so `self` is free afterwards.
        let extracted = {
            match self.online.as_mut() {
                Some(online) => match online.full_tbptt.as_mut() {
                    Some(tbptt) if !tbptt.steps.is_empty() => {
                        let settings = tbptt.settings.take().ok_or_else(|| {
                            anyhow::anyhow!("mamba full tbptt settings are missing")
                        })?;
                        let start_state = tbptt.segment_start_state.take().ok_or_else(|| {
                            anyhow::anyhow!("mamba full tbptt segment start is missing")
                        })?;
                        let steps = std::mem::take(&mut tbptt.steps);
                        // Adam over model params needs the full optimizer
                        // state; note whether it must be allocated below.
                        let need_full_adam = matches!(settings.optimizer, OptimizerKind::Adam)
                            && settings.scope.trains_model_params()
                            && online.full_adam.is_none();
                        Some((settings, start_state, steps, need_full_adam))
                    }
                    _ => None,
                },
                None => None,
            }
        };
        let Some((settings, start_state, steps, need_full_adam)) = extracted else {
            return Ok(());
        };

        if need_full_adam {
            // Allocated outside the `online` borrow: it reads `self.model`.
            let full_adam = self.model.new_full_adam_state();
            if let Some(online) = self.online.as_mut() {
                online.full_adam = Some(full_adam);
            }
        }

        // Flatten the steps into the (input, target, pdf) triples the
        // trainer expects.
        let segment_steps = steps
            .into_iter()
            .map(|step| (step.input_token, step.target_symbol, step.pdf))
            .collect::<Vec<_>>();
        // Copy-on-write: detach the shared model so training can mutate it.
        let model = Arc::make_mut(&mut self.model);
        let Some(online) = self.online.as_mut() else {
            return Ok(());
        };
        model.online_train_segment_tbptt(
            &mut self.scratch,
            &start_state,
            &segment_steps,
            settings.scope,
            settings.optimizer,
            settings.lr,
            settings.clip,
            TBPTT_REPLAY_CHUNK,
            &mut online.adam_t,
            online.full_adam.as_mut(),
            // Bias buffers are only handed over when the scope trains bias.
            if settings.scope.bias {
                Some(online.out_bias.as_mut_slice())
            } else {
                None
            },
            if settings.scope.bias {
                online.adam_m.as_deref_mut()
            } else {
                None
            },
            if settings.scope.bias {
                online.adam_v.as_deref_mut()
            } else {
                None
            },
            &mut self.state,
        )?;
        // Training replay leaves fresh logits; resync the cached pdf.
        let bias = self.online.as_ref().map(|o| o.out_bias.as_slice());
        Self::logits_to_pdf(self.scratch.logits(), bias, &mut self.pdf_buffer);
        Ok(())
    }
1278
    /// Buffers one training step for full-parameter truncated BPTT.
    ///
    /// If `settings` differ from the buffered segment's, the segment is
    /// flushed first so every segment trains under uniform settings. The new
    /// step pairs the token fed forward last (recorded by
    /// `forward_with_online_record`) with the observed target symbol and the
    /// pdf used to code it. Once the buffer reaches `settings.bptt` steps the
    /// segment is trained immediately.
    fn enqueue_full_tbptt_step(
        &mut self,
        settings: FullTrainSettings,
        target_symbol: u8,
        pdf: &[f64],
    ) -> Result<()> {
        // Phase 1: settings changed mid-segment? Flush under the old ones.
        let should_flush = {
            let Some(online) = self.online.as_mut() else {
                return Ok(());
            };
            let Some(tbptt) = online.full_tbptt.as_mut() else {
                bail!("mamba full-parameter online training requires trace-enabled tbptt runtime");
            };
            tbptt.settings.is_some_and(|prev| {
                !prev.matches(
                    settings.optimizer,
                    settings.lr,
                    settings.scope,
                    settings.bptt,
                    settings.clip,
                )
            }) && !tbptt.steps.is_empty()
        };
        if should_flush {
            self.flush_full_tbptt_segment()?;
        }

        // Phase 2: append the step; report whether the segment is now full.
        let flush_now = {
            let Some(online) = self.online.as_mut() else {
                return Ok(());
            };
            let Some(tbptt) = online.full_tbptt.as_mut() else {
                bail!("mamba full-parameter online training requires trace-enabled tbptt runtime");
            };
            // Without a recorded forward step there is nothing to pair with.
            let Some(input_token) = tbptt.pending_input_token.take() else {
                return Ok(());
            };
            let input_pre_state = tbptt
                .pending_input_pre_state
                .take()
                .ok_or_else(|| anyhow::anyhow!("mamba full tbptt pending pre-state is missing"))?;
            // First step of a fresh segment anchors the replay start state.
            if tbptt.steps.is_empty() {
                tbptt.segment_start_state = Some(input_pre_state);
            }
            tbptt.settings = Some(settings);
            tbptt.steps.push(FullTbpttStep {
                input_token,
                target_symbol,
                pdf: pdf.to_vec(),
            });
            tbptt.steps.len() >= settings.bptt.max(1)
        };
        if flush_now {
            self.flush_full_tbptt_segment()?;
        }
        Ok(())
    }
1336
    /// Starts a new policy stream: flushes any buffered segment left over
    /// from the previous stream, then arms the schedule for `total_symbols`.
    pub fn begin_online_policy_stream(&mut self, total_symbols: Option<u64>) -> Result<()> {
        self.finish_online_policy_stream()?;
        self.prepare_policy_stream(total_symbols)
    }
1342
    /// Ends the current policy stream by training on any buffered
    /// (possibly partial) full-TBPTT segment.
    pub fn finish_online_policy_stream(&mut self) -> Result<()> {
        self.flush_full_tbptt_segment()
    }
1347
1348 pub fn restart_online_policy_stream(&mut self, total_symbols: Option<u64>) -> Result<()> {
1350 self.finish_online_policy_stream()?;
1351 self.state.reset();
1352 self.clear_online_training_buffers();
1353 self.prepare_policy_stream(total_symbols)
1354 }
1355
1356 pub fn reset_and_prime(&mut self) {
1358 self.state.reset();
1359 self.clear_online_training_buffers();
1360 self.refresh_current_pdf(0);
1361 }
1362
1363 pub fn snapshot_runtime(&self) -> RuntimeSnapshot {
1365 RuntimeSnapshot {
1366 model: self.model.clone(),
1367 scratch: self.scratch.clone(),
1368 state: self.state.clone(),
1369 pdf_buffer: self.pdf_buffer.clone(),
1370 online: self.online.clone(),
1371 }
1372 }
1373
1374 pub fn restore_runtime(&mut self, snapshot: &RuntimeSnapshot) {
1376 self.model = snapshot.model.clone();
1377 self.scratch = snapshot.scratch.clone();
1378 self.state = snapshot.state.clone();
1379 self.pdf_buffer.clone_from(&snapshot.pdf_buffer);
1380 self.online = snapshot.online.clone();
1381 }
1382
1383 pub fn absorb_chain(&mut self, parts: &[&[u8]]) -> Result<()> {
1385 let total = parts
1386 .iter()
1387 .fold(0u64, |acc, part| acc.saturating_add(part.len() as u64));
1388 self.fit_chain(parts, Some(total))
1389 }
1390
1391 pub fn cross_entropy_from_current(&mut self, data: &[u8]) -> Result<f64> {
1393 if data.is_empty() {
1394 return Ok(0.0);
1395 }
1396 self.begin_online_policy_stream(Some(data.len() as u64))?;
1397 let mut total_bits = 0.0;
1398 for &byte in data {
1399 let p = self.pdf_buffer[byte as usize].max(1e-300);
1400 total_bits -= p.log2();
1401 self.observe_symbol_from_current_pdf(byte)?;
1402 }
1403 self.finish_online_policy_stream()?;
1404 Ok(total_bits / (data.len() as f64))
1405 }
1406
    /// Cross-entropy of `data` under a fit-then-freeze protocol: adapt on
    /// `fit_parts`, reset the recurrent state, then score `data` with
    /// inference-only steps (no further training).
    ///
    /// Falls back to plain `cross_entropy` when no online adaptation is
    /// possible, since fitting would be a no-op. Returns bits per byte.
    pub fn cross_entropy_frozen_plugin_chain(
        &mut self,
        fit_parts: &[&[u8]],
        data: &[u8],
    ) -> Result<f64> {
        if data.is_empty() {
            return Ok(0.0);
        }
        if !self.can_adapt_online() {
            return self.cross_entropy(data);
        }
        self.finish_online_policy_stream()?;
        self.reset_and_prime();
        let fit_total = fit_parts
            .iter()
            .fold(0u64, |acc, part| acc.saturating_add(part.len() as u64));
        self.fit_chain(fit_parts, Some(fit_total))?;
        self.reset_and_prime();

        let mut total_bits = 0.0;
        for &byte in data {
            // 1e-300 floor keeps log2 finite for zero-probability symbols.
            total_bits -= self.pdf_buffer[byte as usize].max(1e-300).log2();
            self.advance_inference_only(byte);
        }
        Ok(total_bits / (data.len() as f64))
    }
1434
    /// True when an online runtime (bias/policy/training state) is attached.
    pub fn is_online(&self) -> bool {
        self.online.is_some()
    }
1439
1440 pub fn can_adapt_online(&self) -> bool {
1442 let Some(online) = &self.online else {
1443 return false;
1444 };
1445 match &online.policy {
1446 Some(policy) => llm_policy::policy_can_train(policy),
1447 None => !matches!(online.cfg.train_mode, OnlineTrainMode::None),
1448 }
1449 }
1450
    /// Total symbols observed by the online runtime (0 when not online).
    pub fn tokens_processed(&self) -> u64 {
        self.online.as_ref().map_or(0, |s| s.tokens_processed)
    }
1455
    /// Canonical method string of the online runtime, if one is attached.
    pub fn online_method_string(&self) -> Option<&str> {
        self.online.as_ref().map(|s| s.canonical_method.as_str())
    }
1460
    /// The model's vocabulary size (number of distinct byte symbols).
    pub fn vocab_size(&self) -> usize {
        self.model.config().vocab_size
    }
1465
1466 pub fn online_apply_logits_bias(&self, logits: &[f32], pdf_out: &mut [f64]) {
1468 let bias = self.online.as_ref().map(|s| s.out_bias.as_slice());
1469 Self::logits_to_pdf(logits, bias, pdf_out);
1470 }
1471
    /// Converts logits (plus an optional additive bias) into a floored
    /// softmax pdf; thin wrapper over `softmax_pdf_floor_with_bias`.
    pub fn logits_to_pdf(logits: &[f32], bias: Option<&[f32]>, pdf_out: &mut [f64]) {
        softmax_pdf_floor_with_bias(logits, bias, pdf_out);
    }
1476
1477 #[inline]
1478 pub fn forward_to_pdf(&mut self, token: u32, pdf_out: &mut [f64]) {
1480 self.forward_with_online_record(token);
1481 let bias = self.online.as_ref().map(|o| o.out_bias.as_slice());
1482 Self::logits_to_pdf(self.scratch.logits(), bias, pdf_out);
1483 }
1484
1485 pub fn online_bias_snapshot(&self) -> Option<Vec<f32>> {
1487 self.online.as_ref().map(|o| o.out_bias.clone())
1488 }
1489
    /// Borrowed view of the online output bias, if online mode is active.
    #[inline]
    pub fn online_bias_slice(&self) -> Option<&[f32]> {
        self.online.as_ref().map(|o| o.out_bias.as_slice())
    }
1495
    /// Public alias for `online_update_with_pdf`: applies a training update
    /// for `symbol` without advancing the model.
    pub fn online_update_from_pdf(&mut self, symbol: u8, pdf: &[f64]) -> Result<()> {
        self.online_update_with_pdf(symbol, pdf)
    }
1500
    /// Decides how (and whether) to train for the next symbol.
    ///
    /// Defaults come from the static `OnlineConfig` (train head + bias with
    /// the configured optimizer/lr/stride); a policy action scheduled for
    /// this step overrides everything, including the train scope.
    /// Returns `(optimizer, lr, stride, scope, bptt, clip)`.
    fn resolve_online_train_action(
        online: &mut OnlineRuntime,
    ) -> Result<(OptimizerKind, f32, u64, mamba1::TrainScopeMask, usize, f32)> {
        let mut optimizer = match online.cfg.train_mode {
            // `None` still picks a placeholder optimizer; the empty default
            // scope below is what actually disables training.
            OnlineTrainMode::None => OptimizerKind::Sgd,
            OnlineTrainMode::Sgd => OptimizerKind::Sgd,
            OnlineTrainMode::Adam => OptimizerKind::Adam,
        };
        let mut lr = online.cfg.lr.max(0.0);
        let mut stride = online.cfg.stride.max(1) as u64;
        let mut scope = mamba1::TrainScopeMask::default();
        let default_train = !matches!(online.cfg.train_mode, OnlineTrainMode::None);
        scope.head = default_train;
        scope.bias = default_train;
        let mut bptt = 1usize;
        let mut clip = 0.0f32;

        if let Some(action) = online.next_policy_action()? {
            match action {
                // Inference-only action: clear the scope so nothing trains.
                PolicyAction::Infer => {
                    scope = mamba1::TrainScopeMask::default();
                }
                PolicyAction::Train(train) => {
                    optimizer = train.optimizer;
                    lr = train.hyper.lr.max(0.0);
                    stride = train.hyper.stride.max(1) as u64;
                    bptt = train.hyper.bptt.max(1);
                    clip = train.hyper.clip.max(0.0);
                    if train.scope.all {
                        scope = mamba1::TrainScopeMask::all();
                    } else {
                        // Rebuild the mask from the policy's named scopes.
                        scope = mamba1::TrainScopeMask::default();
                        scope.embed = train.scope.contains("embed");
                        scope.layer_norm = train.scope.contains("layer_norm");
                        scope.mixer_conv = train.scope.contains("mixer_conv");
                        scope.mixer_ssm = train.scope.contains("mixer_ssm");
                        scope.mixer_proj = train.scope.contains("mixer_proj");
                        scope.head = train.scope.contains("head");
                        scope.bias = train.scope.contains("bias");
                    }
                }
            }
        }
        Ok((optimizer, lr, stride, scope, bptt, clip))
    }
1546
1547 #[inline]
1548 pub fn observe_symbol_from_pdf(&mut self, symbol: u8, pdf: &[f64]) -> Result<()> {
1550 self.online_update_with_pdf(symbol, pdf)?;
1551 self.refresh_current_pdf(symbol as u32);
1552 Ok(())
1553 }
1554
    /// Core per-symbol online update.
    ///
    /// Resolves the train action for this step, honors the stride schedule,
    /// then either (a) buffers the step for full TBPTT (non-head scopes with
    /// bptt > 1) or (b) applies an immediate bptt=1 update via
    /// `online_train_step_bptt1`. Steps that do not train flush any buffered
    /// segment and drop the pending forward record so it is never replayed.
    fn online_update_with_pdf(&mut self, symbol: u8, pdf: &[f64]) -> Result<()> {
        let (optimizer, lr, stride_hit, scope, bptt, clip) = {
            let Some(online) = self.online.as_mut() else {
                return Ok(());
            };
            online.tokens_processed = online.tokens_processed.saturating_add(1);
            let (optimizer, lr, stride, scope, bptt, clip) =
                Self::resolve_online_train_action(online)?;
            let mut stride_hit = false;
            if scope.trains_model_params() || scope.bias {
                // Stride counts only steps whose action would train something.
                online.policy_train_steps = online.policy_train_steps.saturating_add(1);
                stride_hit = stride <= 1 || (online.policy_train_steps % stride) == 0;
            }
            (optimizer, lr, stride_hit, scope, bptt, clip)
        };

        // Nothing to train this step: close out any buffered segment and
        // forget the recorded forward step.
        if (!scope.trains_model_params() && !scope.bias) || !stride_hit || lr == 0.0 {
            self.flush_full_tbptt_segment()?;
            if let Some(online) = self.online.as_mut()
                && let Some(tbptt) = online.full_tbptt.as_mut()
            {
                tbptt.pending_input_token = None;
                tbptt.pending_input_pre_state = None;
            }
            return Ok(());
        }

        // Lazily allocate Adam moment buffers for the output bias.
        if matches!(optimizer, OptimizerKind::Adam)
            && let Some(online) = self.online.as_mut()
            && scope.bias
            && (online.adam_m.is_none() || online.adam_v.is_none())
        {
            online.adam_m = Some(vec![0.0; online.out_bias.len()]);
            online.adam_v = Some(vec![0.0; online.out_bias.len()]);
        }

        // Non-head parameters with bptt > 1 must go through segment replay.
        let trains_non_head = scope.embed
            || scope.layer_norm
            || scope.mixer_conv
            || scope.mixer_ssm
            || scope.mixer_proj;
        if trains_non_head && bptt > 1 {
            let settings = FullTrainSettings {
                optimizer,
                lr,
                scope,
                bptt,
                clip,
            };
            return self.enqueue_full_tbptt_step(settings, symbol, pdf);
        }

        // Immediate (bptt = 1) path: discard any buffered segment state first.
        self.flush_full_tbptt_segment()?;
        if let Some(online) = self.online.as_mut()
            && let Some(tbptt) = online.full_tbptt.as_mut()
        {
            tbptt.pending_input_token = None;
            tbptt.pending_input_pre_state = None;
        }
        if scope.trains_model_params() {
            self.scratch.set_capture_train_trace(true);
        }
        // Lazily allocate the full Adam state for model parameters.
        if matches!(optimizer, OptimizerKind::Adam)
            && scope.trains_model_params()
            && self.online.as_ref().is_some_and(|o| o.full_adam.is_none())
        {
            let full_adam = self.model.as_ref().new_full_adam_state();
            if let Some(online) = self.online.as_mut()
                && online.full_adam.is_none()
            {
                online.full_adam = Some(full_adam);
            }
        }

        // Copy-on-write before mutating shared model weights.
        let model = Arc::make_mut(&mut self.model);
        let Some(online) = self.online.as_mut() else {
            return Ok(());
        };
        // Destructure so disjoint fields can be borrowed independently below.
        let OnlineRuntime {
            out_bias,
            adam_m,
            adam_v,
            full_adam,
            adam_t,
            ..
        } = online;
        model.online_train_step_bptt1(
            &mut self.scratch,
            &self.state,
            symbol,
            pdf,
            scope,
            optimizer,
            lr,
            clip,
            adam_t,
            full_adam.as_mut(),
            if scope.bias {
                Some(out_bias.as_mut_slice())
            } else {
                None
            },
            if scope.bias {
                adam_m.as_deref_mut()
            } else {
                None
            },
            if scope.bias {
                adam_v.as_deref_mut()
            } else {
                None
            },
        )
    }
1669
1670 #[inline]
1671 fn refresh_current_pdf(&mut self, token: u32) {
1672 self.forward_with_online_record(token);
1673 let bias = self.online.as_ref().map(|o| o.out_bias.as_slice());
1674 Self::logits_to_pdf(self.scratch.logits(), bias, &mut self.pdf_buffer);
1675 }
1676
1677 fn fit_chain(&mut self, parts: &[&[u8]], total_symbols: Option<u64>) -> Result<()> {
1678 self.begin_online_policy_stream(total_symbols)?;
1679 for part in parts {
1680 for &byte in *part {
1681 self.observe_symbol_from_current_pdf(byte)?;
1682 }
1683 }
1684 self.finish_online_policy_stream()?;
1685 Ok(())
1686 }
1687
    /// Advances the model on `symbol` without any training update
    /// (just the forward pass plus pdf refresh).
    #[inline]
    fn advance_inference_only(&mut self, symbol: u8) {
        self.refresh_current_pdf(symbol as u32);
    }
1692
1693 fn online_update_from_current_pdf(&mut self, symbol: u8) -> Result<()> {
1694 let pdf_snapshot = self.pdf_buffer.clone();
1695 self.online_update_with_pdf(symbol, &pdf_snapshot)
1696 }
1697
1698 #[inline]
1699 pub fn observe_symbol_from_current_pdf(&mut self, symbol: u8) -> Result<()> {
1701 self.online_update_from_current_pdf(symbol)?;
1702 self.refresh_current_pdf(symbol as u32);
1703 Ok(())
1704 }
1705
    /// Saves model weights plus online-training metadata next to them.
    ///
    /// Writes `model_path` (safetensors weights), `<stem>.json` (method,
    /// policy cursor, config, bias/Adam vectors), and — only when a full
    /// Adam state exists — `<stem>.opt.safetensors`. A stale optimizer
    /// sidecar is deleted so a later resume cannot pick it up by mistake.
    ///
    /// # Errors
    /// Fails on weight/optimizer serialization errors or filesystem errors
    /// while writing the JSON sidecar.
    pub fn export_online<P: AsRef<Path>>(&self, model_path: P) -> Result<()> {
        let model_path = model_path.as_ref();
        self.model.save_safetensors(model_path)?;
        let opt_sidecar = optimizer_sidecar_path(model_path);

        let sidecar = model_path.with_extension("json");
        let meta = if let Some(online) = &self.online {
            if let Some(full_adam) = online.full_adam.as_ref() {
                self.model
                    .save_full_adam_safetensors(full_adam, &opt_sidecar)?;
            } else if opt_sidecar.exists() {
                // Best effort: a leftover sidecar would poison a later resume.
                let _ = fs::remove_file(&opt_sidecar);
            }
            let train_mode = match online.cfg.train_mode {
                OnlineTrainMode::None => "none",
                OnlineTrainMode::Sgd => "sgd",
                OnlineTrainMode::Adam => "adam",
            };
            json!({
                "version": 1,
                "method": online.canonical_method,
                "policy": online.policy.as_ref().map(LlmPolicy::canonical),
                "policy_cursor": online.policy_runtime.as_ref().map(PolicyRuntime::cursor).unwrap_or(0),
                "policy_stream_total": online.policy_stream_total,
                "policy_train_steps": online.policy_train_steps,
                "training_mode": train_mode,
                "tokens_processed": online.tokens_processed,
                "adam_t": online.adam_t,
                "has_full_adam": online.full_adam.is_some(),
                "config": {
                    "hidden": online.cfg.hidden,
                    "layers": online.cfg.layers,
                    "intermediate": online.cfg.intermediate,
                    "state": online.cfg.state,
                    "conv": online.cfg.conv,
                    "dt_rank": online.cfg.dt_rank,
                    "seed": online.cfg.seed,
                    "lr": online.cfg.lr,
                    "stride": online.cfg.stride.max(1),
                },
                "output_bias": online.out_bias,
                "adam_m": online.adam_m,
                "adam_v": online.adam_v,
                "lm_head_adam_m": online.lm_head_adam_m,
                "lm_head_adam_v": online.lm_head_adam_v,
            })
        } else {
            // Offline runtime: minimal metadata, and no optimizer sidecar.
            if opt_sidecar.exists() {
                let _ = fs::remove_file(&opt_sidecar);
            }
            json!({
                "version": 1,
                "method": format!("file:{}", model_path.display()),
                "training_mode": "none",
                "tokens_processed": 0,
            })
        };

        fs::write(sidecar, serde_json::to_vec_pretty(&meta)?)?;
        Ok(())
    }
1768
    /// Restores online-training state from the JSON sidecar written by
    /// `export_online`, when this runtime was loaded from a model file.
    ///
    /// A missing sidecar is not an error (fresh model). When the sidecar
    /// records an online session (marked by an `output_bias` array), this
    /// rebuilds the `OnlineRuntime` (config, policy, bias and Adam vectors,
    /// counters), reloads the optimizer sidecar when present, and re-seeks
    /// the policy stream to its saved cursor so resume is exact.
    ///
    /// # Errors
    /// Fails on unreadable/invalid JSON, on a missing optimizer sidecar the
    /// metadata says is required, or on policy stream preparation errors.
    fn maybe_load_sidecar(&mut self) -> Result<()> {
        let Some(model_path) = &self.source_model_path else {
            return Ok(());
        };
        let sidecar = model_path.with_extension("json");
        if !sidecar.exists() {
            return Ok(());
        }

        let raw = fs::read(&sidecar)?;
        let v: serde_json::Value = serde_json::from_slice(&raw)?;
        // Helper: read an optional JSON number array as Vec<f32>
        // (non-numeric entries become 0.0).
        let parse_vec_f32 = |key: &str| -> Option<Vec<f32>> {
            v.get(key).and_then(|arr| arr.as_array()).map(|arr| {
                arr.iter()
                    .map(|x| x.as_f64().unwrap_or(0.0) as f32)
                    .collect::<Vec<f32>>()
            })
        };
        let output_bias = v
            .get("output_bias")
            .and_then(|arr| arr.as_array())
            .map(|arr| {
                arr.iter()
                    .map(|x| x.as_f64().unwrap_or(0.0) as f32)
                    .collect::<Vec<f32>>()
            });

        let method = v
            .get("method")
            .and_then(|m| m.as_str())
            .map(|s| s.to_string())
            .unwrap_or_else(|| format!("file:{}", model_path.display()));
        let has_full_adam = v
            .get("has_full_adam")
            .and_then(|x| x.as_bool())
            .unwrap_or(false);
        // Unparseable policies are silently dropped (best-effort resume).
        let policy = v
            .get("policy")
            .and_then(|p| p.as_str())
            .and_then(|s| llm_policy::parse_policy_segment(s, MAMBA_TRAIN_SCOPES).ok());
        let tokens = v
            .get("tokens_processed")
            .and_then(|t| t.as_u64())
            .unwrap_or(0);

        // Presence of "output_bias" is the marker for an online session.
        if let Some(mut out_bias) = output_bias {
            // Pad/truncate to the current vocab so indexing stays in bounds.
            out_bias.resize(self.vocab_size(), 0.0);
            // Rebuild the OnlineConfig field by field; absent keys keep defaults.
            let mut cfg = OnlineConfig::default();
            if let Some(cfg_v) = v.get("config").and_then(|x| x.as_object()) {
                if let Some(x) = cfg_v.get("hidden").and_then(|x| x.as_u64()) {
                    cfg.hidden = x as usize;
                }
                if let Some(x) = cfg_v.get("layers").and_then(|x| x.as_u64()) {
                    cfg.layers = x as usize;
                }
                if let Some(x) = cfg_v.get("intermediate").and_then(|x| x.as_u64()) {
                    cfg.intermediate = x as usize;
                }
                if let Some(x) = cfg_v.get("state").and_then(|x| x.as_u64()) {
                    cfg.state = x as usize;
                }
                if let Some(x) = cfg_v.get("conv").and_then(|x| x.as_u64()) {
                    cfg.conv = x as usize;
                }
                if let Some(x) = cfg_v.get("dt_rank").and_then(|x| x.as_u64()) {
                    cfg.dt_rank = x as usize;
                }
                if let Some(x) = cfg_v.get("seed").and_then(|x| x.as_u64()) {
                    cfg.seed = x;
                }
                if let Some(x) = cfg_v.get("lr").and_then(|x| x.as_f64()) {
                    cfg.lr = x as f32;
                }
                if let Some(x) = cfg_v.get("stride").and_then(|x| x.as_u64()) {
                    cfg.stride = (x as usize).max(1);
                }
            }
            cfg.train_mode = v
                .get("training_mode")
                .and_then(|x| x.as_str())
                .and_then(|s| parse_train_mode_token(s).ok())
                .unwrap_or(OnlineTrainMode::None);
            let needs_full_trace = policy
                .as_ref()
                .map(policy_needs_full_trace)
                .unwrap_or(false);

            self.online = Some(OnlineRuntime {
                cfg,
                canonical_method: method,
                policy,
                policy_runtime: None,
                needs_full_trace,
                policy_stream_total: v.get("policy_stream_total").and_then(|x| x.as_u64()),
                policy_train_steps: v
                    .get("policy_train_steps")
                    .and_then(|x| x.as_u64())
                    .unwrap_or(0),
                tokens_processed: tokens,
                out_bias,
                adam_m: parse_vec_f32("adam_m"),
                adam_v: parse_vec_f32("adam_v"),
                full_adam: None,
                lm_head_adam_m: parse_vec_f32("lm_head_adam_m"),
                lm_head_adam_v: parse_vec_f32("lm_head_adam_v"),
                adam_t: v.get("adam_t").and_then(|x| x.as_u64()).unwrap_or(0) as usize,
                full_tbptt: needs_full_trace.then(|| FullTbpttRuntime {
                    pending_input_token: None,
                    pending_input_pre_state: None,
                    segment_start_state: None,
                    steps: Vec::new(),
                    settings: None,
                }),
            });
            let opt_sidecar = optimizer_sidecar_path(model_path);
            if opt_sidecar.exists() {
                if let Some(online) = self.online.as_mut() {
                    online.full_adam = Some(self.model.load_full_adam_safetensors(&opt_sidecar)?);
                }
            } else if has_full_adam {
                // Metadata promises an optimizer state we cannot find;
                // resuming without it would silently diverge.
                bail!(
                    "missing optimizer sidecar '{}' required for exact online resume",
                    opt_sidecar.display()
                );
            }
            // Re-arm the policy stream and seek to the saved cursor; the
            // train-step counter is preserved across `prepare_policy_stream`,
            // which may reset it.
            if let Some(cursor) = v.get("policy_cursor").and_then(|x| x.as_u64())
                && let Some(online) = self.online.as_mut()
                && online.policy.is_some()
            {
                let train_steps = online.policy_train_steps;
                online.prepare_policy_stream(online.policy_stream_total)?;
                online.policy_train_steps = train_steps;
                if let Some(rt) = online.policy_runtime.as_mut() {
                    rt.set_cursor(cursor);
                }
            }
            self.scratch
                .set_capture_train_trace(self.online.as_ref().is_some_and(|o| o.needs_full_trace));
        }
        Ok(())
    }
1910
1911 pub fn compress_into<W: Write>(
1913 &mut self,
1914 data: &[u8],
1915 coder: CoderType,
1916 w: &mut W,
1917 ) -> Result<()> {
1918 self.restart_online_policy_stream(Some(data.len() as u64))?;
1919 let checksum = crc32(data);
1920 let header = Header::new(coder, data.len() as u64, checksum);
1921 header.write(w)?;
1922
1923 match coder {
1924 CoderType::AC => self.compress_ac_iter(data.iter().copied(), w)?,
1925 CoderType::RANS => self.compress_rans_iter(data.iter().copied(), w)?,
1926 }
1927 self.finish_online_policy_stream()?;
1928 Ok(())
1929 }
1930
1931 pub fn compress_chain_into<W: Write>(
1933 &mut self,
1934 parts: &[&[u8]],
1935 coder: CoderType,
1936 w: &mut W,
1937 ) -> Result<()> {
1938 let mut total_len: u64 = 0;
1939 let mut hasher = crc32fast::Hasher::new();
1940 for p in parts {
1941 total_len = total_len.saturating_add(p.len() as u64);
1942 hasher.update(p);
1943 }
1944 self.restart_online_policy_stream(Some(total_len))?;
1945 let checksum = hasher.finalize();
1946 let header = Header::new(coder, total_len, checksum);
1947 header.write(w)?;
1948
1949 let it = parts.iter().flat_map(|p| p.iter().copied());
1950 match coder {
1951 CoderType::AC => self.compress_ac_iter(it, w)?,
1952 CoderType::RANS => self.compress_rans_iter(it, w)?,
1953 }
1954 self.finish_online_policy_stream()?;
1955 Ok(())
1956 }
1957
1958 pub fn compress_size(&mut self, data: &[u8], coder: CoderType) -> Result<u64> {
1960 let mut w = CountingWriter::new();
1961 self.compress_into(data, coder, &mut w)?;
1962 Ok(w.bytes_written())
1963 }
1964
1965 pub fn compress_size_chain(&mut self, parts: &[&[u8]], coder: CoderType) -> Result<u64> {
1967 let mut w = CountingWriter::new();
1968 self.compress_chain_into(parts, coder, &mut w)?;
1969 Ok(w.bytes_written())
1970 }
1971
1972 pub fn compress(&mut self, data: &[u8], coder: CoderType) -> Result<Vec<u8>> {
1974 let mut out = Vec::new();
1975 self.compress_into(data, coder, &mut out)?;
1976 Ok(out)
1977 }
1978
1979 fn compress_ac_iter<I, W: Write>(&mut self, data: I, output: &mut W) -> Result<()>
1980 where
1981 I: IntoIterator<Item = u8>,
1982 {
1983 let mut encoder = ArithmeticEncoder::new(output);
1984
1985 self.refresh_current_pdf(0);
1986
1987 for byte in data {
1988 quantize_pdf_to_cdf_with_buffer(
1989 &self.pdf_buffer,
1990 &mut self.cdf_buffer_ac,
1991 &mut self.ac_freq_buffer,
1992 );
1993 let sym = byte as usize;
1994 let lo = self.cdf_buffer_ac[sym] as u64;
1995 let hi = self.cdf_buffer_ac[sym + 1] as u64;
1996 encoder.encode_counts(lo, hi, CDF_TOTAL as u64)?;
1997 self.observe_symbol_from_current_pdf(byte)?;
1998 }
1999
2000 let _ = encoder.finish()?;
2001 Ok(())
2002 }
2003
2004 fn compress_rans_iter<I, W: Write>(&mut self, data: I, output: &mut W) -> Result<()>
2005 where
2006 I: IntoIterator<Item = u8>,
2007 {
2008 let mut encoder = BlockedRansEncoder::new();
2009
2010 self.refresh_current_pdf(0);
2011
2012 for byte in data {
2013 quantize_pdf_to_rans_cdf_with_buffer(
2014 &self.pdf_buffer,
2015 &mut self.cdf_buffer_rans,
2016 &mut self.rans_freq_buffer,
2017 );
2018 let sym = byte as usize;
2019 let cdf = Cdf::new(
2020 self.cdf_buffer_rans[sym],
2021 self.cdf_buffer_rans[sym + 1],
2022 ANS_TOTAL,
2023 );
2024 encoder.encode(cdf);
2025 self.observe_symbol_from_current_pdf(byte)?;
2026 }
2027
2028 let blocks = encoder.finish();
2029 output.write_all(&(blocks.len() as u32).to_le_bytes())?;
2030 for block in &blocks {
2031 output.write_all(&(block.len() as u32).to_le_bytes())?;
2032 output.write_all(block)?;
2033 }
2034 Ok(())
2035 }
2036
2037 pub fn decompress(&mut self, data: &[u8]) -> Result<Vec<u8>> {
2039 let mut cursor = Cursor::new(data);
2040 let header = Header::read(&mut cursor)?;
2041
2042 self.restart_online_policy_stream(Some(header.original_len))?;
2043 let compressed = &data[Header::SIZE..];
2044 let result = match header.coder_type() {
2045 CoderType::AC => self.decompress_ac(compressed, header.original_len as usize)?,
2046 CoderType::RANS => self.decompress_rans(compressed, header.original_len as usize)?,
2047 };
2048
2049 let actual_crc = crc32(&result);
2050 if actual_crc != header.crc32 {
2051 bail!(
2052 "CRC32 mismatch: expected 0x{:08X}, got 0x{:08X}",
2053 header.crc32,
2054 actual_crc
2055 );
2056 }
2057 self.finish_online_policy_stream()?;
2058 Ok(result)
2059 }
2060
2061 fn decompress_ac(&mut self, compressed: &[u8], original_len: usize) -> Result<Vec<u8>> {
2062 let mut decoder = ArithmeticDecoder::new(compressed)?;
2063 let mut result = Vec::with_capacity(original_len);
2064
2065 self.refresh_current_pdf(0);
2066
2067 for _ in 0..original_len {
2068 quantize_pdf_to_cdf_with_buffer(
2069 &self.pdf_buffer,
2070 &mut self.cdf_buffer_ac,
2071 &mut self.ac_freq_buffer,
2072 );
2073 let sym = decoder.decode_symbol_counts(&self.cdf_buffer_ac, CDF_TOTAL)?;
2074 let byte = sym as u8;
2075 result.push(byte);
2076 self.observe_symbol_from_current_pdf(byte)?;
2077 }
2078
2079 Ok(result)
2080 }
2081
2082 fn decompress_rans(&mut self, compressed: &[u8], original_len: usize) -> Result<Vec<u8>> {
2083 if compressed.len() < 4 {
2084 bail!("rANS data too short");
2085 }
2086 let block_count =
2087 u32::from_le_bytes([compressed[0], compressed[1], compressed[2], compressed[3]])
2088 as usize;
2089
2090 let mut blocks = Vec::with_capacity(block_count);
2091 let mut pos = 4usize;
2092 for _ in 0..block_count {
2093 if pos + 4 > compressed.len() {
2094 bail!("truncated rANS block header");
2095 }
2096 let len = u32::from_le_bytes([
2097 compressed[pos],
2098 compressed[pos + 1],
2099 compressed[pos + 2],
2100 compressed[pos + 3],
2101 ]) as usize;
2102 pos += 4;
2103 if pos + len > compressed.len() {
2104 bail!("truncated rANS block data");
2105 }
2106 blocks.push(&compressed[pos..pos + len]);
2107 pos += len;
2108 }
2109
2110 let mut decoder = BlockedRansDecoder::new(blocks, original_len)?;
2111 let mut result = Vec::with_capacity(original_len);
2112
2113 self.refresh_current_pdf(0);
2114
2115 for _ in 0..original_len {
2116 quantize_pdf_to_rans_cdf_with_buffer(
2117 &self.pdf_buffer,
2118 &mut self.cdf_buffer_rans,
2119 &mut self.rans_freq_buffer,
2120 );
2121 let sym = decoder.decode(&self.cdf_buffer_rans)? as u8;
2122 result.push(sym);
2123 self.observe_symbol_from_current_pdf(sym)?;
2124 }
2125
2126 Ok(result)
2127 }
2128
    /// Cross-entropy (bits per byte) of `data` from a freshly reset state.
    pub fn cross_entropy(&mut self, data: &[u8]) -> Result<f64> {
        self.reset_and_prime();
        self.cross_entropy_from_current(data)
    }
2134
2135 pub fn cross_entropy_conditional_chain(
2137 &mut self,
2138 prefix_parts: &[&[u8]],
2139 data: &[u8],
2140 ) -> Result<f64> {
2141 if data.is_empty() {
2142 return Ok(0.0);
2143 }
2144 let prefix_len = prefix_parts
2145 .iter()
2146 .fold(0usize, |acc, p| acc.saturating_add(p.len()));
2147 self.finish_online_policy_stream()?;
2148 self.reset_and_prime();
2149 self.fit_chain(prefix_parts, Some((prefix_len + data.len()) as u64))?;
2150
2151 let mut total_bits = 0.0;
2152 for &byte in data {
2153 total_bits -= self.pdf_buffer[byte as usize].max(1e-300).log2();
2154 self.observe_symbol_from_current_pdf(byte)?;
2155 }
2156 self.finish_online_policy_stream()?;
2157 Ok(total_bits / (data.len() as f64))
2158 }
2159
    /// Conditional cross-entropy of `data` given a single `prefix`.
    pub fn cross_entropy_conditional(&mut self, prefix: &[u8], data: &[u8]) -> Result<f64> {
        self.cross_entropy_conditional_chain(&[prefix], data)
    }
2164
2165 pub fn joint_cross_entropy_aligned_min(&mut self, x: &[u8], y: &[u8]) -> Result<f64> {
2167 let n = x.len().min(y.len());
2168 if n == 0 {
2169 return Ok(0.0);
2170 }
2171 let h_xy = self.joint_cross_entropy_aligned_order(x, y, false)?;
2172 let h_yx = self.joint_cross_entropy_aligned_order(x, y, true)?;
2173 Ok(h_xy.min(h_yx))
2174 }
2175
2176 fn joint_cross_entropy_aligned_order(&mut self, x: &[u8], y: &[u8], swap: bool) -> Result<f64> {
2177 let n = x.len().min(y.len());
2178 if n == 0 {
2179 return Ok(0.0);
2180 }
2181
2182 self.restart_online_policy_stream(Some((2 * n) as u64))?;
2183
2184 self.refresh_current_pdf(0);
2185
2186 let mut total_bits = 0.0;
2187 for idx in 0..n {
2188 let a = if swap { y[idx] } else { x[idx] };
2189 let b = if swap { x[idx] } else { y[idx] };
2190
2191 total_bits -= self.pdf_buffer[a as usize].max(1e-300).log2();
2192 self.observe_symbol_from_current_pdf(a)?;
2193
2194 total_bits -= self.pdf_buffer[b as usize].max(1e-300).log2();
2195 self.observe_symbol_from_current_pdf(b)?;
2196 }
2197
2198 self.finish_online_policy_stream()?;
2199 Ok(total_bits / (n as f64))
2200 }
2201}