Skip to main content

infotheory/
diagnostics.rs

1//! Exact AC/log-loss diagnostics for mixture compression.
2
3use anyhow::{Context, Result, bail};
4
5use crate::compression::{AcLogLossNodeValue, DiagnosticRatePredictor};
6use crate::mixture::RateBackendPredictor;
7use crate::{MixtureExpertSpec, MixtureKind, MixtureSpec, RateBackend};
8use std::env;
9use std::fs::File;
10use std::io::{BufWriter, Write};
11use std::path::{Path, PathBuf};
12use std::sync::Arc;
13
/// One row of the flattened (preorder) mixture tree.
///
/// Node `0` is always the root mixture; every other node corresponds to a single
/// `MixtureExpertSpec`, which is either a leaf backend or a nested mixture.
#[derive(Clone, Debug)]
struct FlatNodeMeta {
    /// Preorder index of the node; equal to its position in `FlatSchema::nodes`.
    id: usize,
    /// Id of the enclosing mixture node, or `None` for the root.
    parent_id: Option<usize>,
    /// Nesting depth: 0 for the root, 1 for its direct experts, and so on.
    depth: usize,
    /// `/`-separated chain of `id:name` segments from the root (e.g. `0:root/2:nested`).
    path: String,
    /// TSV-safe expert name (explicit name, or the backend's default name).
    display_name: String,
    /// TSV-safe description of the backend and its parameters.
    backend_label: String,
    /// True when this node is itself a mixture with child experts.
    is_mixture: bool,
    /// True for leaf experts; always the negation of `is_mixture`.
    is_leaf: bool,
    /// True when the node is a direct child of the root mixture.
    is_root_child: bool,
}
26
27#[derive(Clone, Debug)]
28struct FlatSchema {
29    nodes: Vec<FlatNodeMeta>,
30    non_root_ids: Vec<usize>,
31    root_child_ids: Vec<usize>,
32}
33
/// Running per-node totals, accumulated over every position of a diagnostic run.
#[derive(Clone, Copy, Debug, Default)]
struct NodeSummaryAccum {
    /// Sum of the node's per-position code lengths, in bits.
    total_bits: f64,
    /// Sum of the node's `local_weight` values reported by the predictor.
    total_local_weight: f64,
    /// Sum of the node's `effective_weight` values reported by the predictor.
    total_effective_weight: f64,
    /// Number of positions at which this node was the cheapest (hindsight oracle) node.
    oracle_win_count: u64,
}
41
/// Output summary for an AC/log-loss diagnostic run.
///
/// The diagnostic writer emits three TSV files sharing a common prefix:
///
/// - `*.trace.tsv`: per-position mixture/expert probabilities and weights
/// - `*.nodes.tsv`: flattened mixture-node schema used by trace columns
/// - `*.summary.tsv`: aggregate totals/averages over the full sequence
#[derive(Clone, Debug)]
pub struct AcLogLossRunSummary {
    /// Where the per-position trace TSV was written.
    pub trace_path: PathBuf,
    /// Where the node-schema TSV (trace column to mixture node mapping) was written.
    pub nodes_path: PathBuf,
    /// Where the aggregate summary TSV (one header row, one data row) was written.
    pub summary_path: PathBuf,
    /// Number of input positions (bytes) that were processed.
    pub positions: usize,
    /// Ideal mixture code length in bits: the sum over positions of the clamped
    /// `-log2` of the mixture's probability for the actual byte.
    pub mix_total_bits: f64,
    /// Hindsight-oracle code length in bits: at each position the cheapest single
    /// non-root node (leaf expert or nested mixture) is chosen, and those costs are summed.
    pub oracle_total_bits: f64,
    /// Size of the raw (unframed) arithmetic-coder output, in bits (emitted bytes × 8).
    pub ac_payload_bits_raw: u64,
}
66
/// A `Write` sink that discards its input and only records how many bytes it
/// was handed. It lets the arithmetic coder run without buffering its output.
#[derive(Default)]
struct CountingWriter {
    /// Total bytes accepted so far; saturates instead of overflowing.
    bytes_written: u64,
}

impl Write for CountingWriter {
    /// Accept the whole buffer, bump the counter, and report every byte as written.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let accepted = buf.len();
        self.bytes_written = self.bytes_written.saturating_add(accepted as u64);
        Ok(accepted)
    }

    /// Nothing is buffered, so flushing is a no-op.
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}
82
83fn parse_diagnostic_threads_from_env() -> Result<usize> {
84    match env::var("RAYON_NUM_THREADS") {
85        Ok(raw) => {
86            let threads = raw
87                .parse::<usize>()
88                .with_context(|| format!("invalid RAYON_NUM_THREADS value '{raw}'"))?;
89            if threads == 0 {
90                bail!("RAYON_NUM_THREADS must be >= 1");
91            }
92            Ok(threads)
93        }
94        Err(env::VarError::NotPresent) => Ok(1),
95        Err(err) => Err(err).context("failed to read RAYON_NUM_THREADS"),
96    }
97}
98
/// Make `input` safe to place in a single TSV cell.
///
/// Each tab, line feed, and carriage return becomes one space, so the character
/// count is unchanged. All other characters pass through untouched.
fn sanitize_tsv_text(input: &str) -> String {
    input.replace(['\t', '\n', '\r'], " ")
}
108
109fn sanitize_path_segment(input: &str) -> String {
110    let clean = sanitize_tsv_text(input);
111    let mut out = String::with_capacity(clean.len());
112    for ch in clean.chars() {
113        match ch {
114            '/' | '\\' => out.push('_'),
115            _ => out.push(ch),
116        }
117    }
118    if out.is_empty() {
119        "node".to_string()
120    } else {
121        out
122    }
123}
124
/// Format a float in scientific notation with 17 fractional digits.
///
/// That is 18 significant digits, more than the 17 needed to round-trip any
/// `f64`, so TSV consumers can recover the exact value. Non-finite values use
/// Rust's standard spellings (`NaN`, `inf`, `-inf`).
fn format_f64(value: f64) -> String {
    const FRACTION_DIGITS: usize = 17;
    format!("{:.*e}", FRACTION_DIGITS, value)
}
128
/// Render an optional id as a TSV cell: its decimal form, or an empty cell for `None`.
fn format_optional_usize(value: Option<usize>) -> String {
    match value {
        Some(number) => number.to_string(),
        None => String::new(),
    }
}
132
/// Encode a boolean as a `"1"` / `"0"` TSV flag.
fn bool_flag(value: bool) -> &'static str {
    match value {
        true => "1",
        false => "0",
    }
}
136
137fn bits_from_prob(prob: f64) -> f64 {
138    -prob.max(crate::mixture::DEFAULT_MIN_PROB).log2()
139}
140
141fn mixture_kind_label(kind: MixtureKind) -> &'static str {
142    match kind {
143        MixtureKind::Bayes => "mixture:bayes",
144        MixtureKind::FadingBayes => "mixture:fading-bayes",
145        MixtureKind::Switching => "mixture:switching",
146        MixtureKind::Convex => "mixture:convex",
147        MixtureKind::Mdl => "mixture:mdl",
148        MixtureKind::Neural => "mixture:neural",
149    }
150}
151
152fn backend_label(expert: &MixtureExpertSpec) -> String {
153    match &expert.backend {
154        RateBackend::RosaPlus => format!("rosaplus(max_order={})", expert.max_order),
155        RateBackend::Match {
156            hash_bits,
157            min_len,
158            max_len,
159            base_mix,
160            confidence_scale,
161        } => format!(
162            "match(hash_bits={hash_bits},min_len={min_len},max_len={max_len},base_mix={base_mix},confidence_scale={confidence_scale})"
163        ),
164        RateBackend::SparseMatch {
165            hash_bits,
166            min_len,
167            max_len,
168            gap_min,
169            gap_max,
170            base_mix,
171            confidence_scale,
172        } => format!(
173            "sparse-match(hash_bits={hash_bits},min_len={min_len},max_len={max_len},gap_min={gap_min},gap_max={gap_max},base_mix={base_mix},confidence_scale={confidence_scale})"
174        ),
175        RateBackend::Ppmd { order, memory_mb } => {
176            format!("ppmd(order={order},memory_mb={memory_mb})")
177        }
178        RateBackend::Sequitur { context_bytes } => {
179            format!("sequitur(context_bytes={context_bytes})")
180        }
181        RateBackend::Ctw { depth } => format!("ctw(depth={depth})"),
182        RateBackend::FacCtw {
183            base_depth,
184            num_percept_bits,
185            encoding_bits,
186        } => format!(
187            "fac-ctw(base_depth={base_depth},num_percept_bits={num_percept_bits},encoding_bits={encoding_bits})"
188        ),
189        #[cfg(feature = "backend-mamba")]
190        RateBackend::Mamba { .. } => "mamba".to_string(),
191        #[cfg(feature = "backend-mamba")]
192        RateBackend::MambaMethod { method } => format!("mamba(method={method})"),
193        #[cfg(feature = "backend-rwkv")]
194        RateBackend::Rwkv7 { .. } => "rwkv7".to_string(),
195        #[cfg(feature = "backend-rwkv")]
196        RateBackend::Rwkv7Method { method } => format!("rwkv7(method={method})"),
197        RateBackend::Zpaq { method } => format!("zpaq(method={method})"),
198        RateBackend::Mixture { spec } => mixture_kind_label(spec.kind).to_string(),
199        RateBackend::Particle { spec } => format!(
200            "particle(num_particles={},num_cells={})",
201            spec.num_particles, spec.num_cells
202        ),
203        RateBackend::Calibrated { spec } => format!(
204            "calibrated(context={:?},bins={},learning_rate={},bias_clip={})",
205            spec.context, spec.bins, spec.learning_rate, spec.bias_clip
206        ),
207    }
208}
209
210fn flatten_mixture_spec(spec: &MixtureSpec) -> FlatSchema {
211    let mut schema = FlatSchema {
212        nodes: vec![FlatNodeMeta {
213            id: 0,
214            parent_id: None,
215            depth: 0,
216            path: "0:root".to_string(),
217            display_name: "root".to_string(),
218            backend_label: mixture_kind_label(spec.kind).to_string(),
219            is_mixture: true,
220            is_leaf: false,
221            is_root_child: false,
222        }],
223        non_root_ids: Vec::new(),
224        root_child_ids: Vec::new(),
225    };
226    flatten_experts(&mut schema, &spec.experts, 0, 1, "0:root", true);
227    schema
228}
229
230fn flatten_experts(
231    schema: &mut FlatSchema,
232    experts: &[MixtureExpertSpec],
233    parent_id: usize,
234    depth: usize,
235    parent_path: &str,
236    root_level: bool,
237) {
238    for expert in experts {
239        let raw_display_name = expert.name.clone().unwrap_or_else(|| {
240            RateBackendPredictor::default_name(&expert.backend, expert.max_order)
241        });
242        let display_name = sanitize_tsv_text(&raw_display_name);
243        let node_id = schema.nodes.len();
244        let path = format!(
245            "{parent_path}/{}:{}",
246            node_id,
247            sanitize_path_segment(&display_name)
248        );
249        let is_mixture = matches!(expert.backend, RateBackend::Mixture { .. });
250        let meta = FlatNodeMeta {
251            id: node_id,
252            parent_id: Some(parent_id),
253            depth,
254            path,
255            display_name,
256            backend_label: sanitize_tsv_text(&backend_label(expert)),
257            is_mixture,
258            is_leaf: !is_mixture,
259            is_root_child: root_level,
260        };
261        schema.nodes.push(meta);
262        schema.non_root_ids.push(node_id);
263        if root_level {
264            schema.root_child_ids.push(node_id);
265        }
266        if let RateBackend::Mixture { spec } = &expert.backend {
267            let node_path = schema.nodes[node_id].path.clone();
268            flatten_experts(schema, &spec.experts, node_id, depth + 1, &node_path, false);
269        }
270    }
271}
272
273fn write_nodes_tsv(path: &Path, schema: &FlatSchema) -> Result<()> {
274    let file = File::create(path)
275        .with_context(|| format!("failed to create nodes TSV {}", path.display()))?;
276    let mut writer = BufWriter::new(file);
277    writeln!(
278        writer,
279        "node_id\tparent_id\tdepth\tpath\tdisplay_name\tbackend_label\tis_mixture\tis_leaf\tis_root_child"
280    )?;
281    for node in &schema.nodes {
282        writeln!(
283            writer,
284            "{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}",
285            node.id,
286            format_optional_usize(node.parent_id),
287            node.depth,
288            node.path,
289            node.display_name,
290            node.backend_label,
291            bool_flag(node.is_mixture),
292            bool_flag(node.is_leaf),
293            bool_flag(node.is_root_child),
294        )?;
295    }
296    writer.flush()?;
297    Ok(())
298}
299
300/// Run exact AC/log-loss diagnostics for a byte sequence under a mixture spec.
301///
302/// This function validates `spec`, evaluates mixture and expert probabilities at
303/// each input position, and writes three TSV artifacts using `out_prefix`:
304///
305/// - `out_prefix.trace.tsv`: per-position diagnostics and expert rows
306/// - `out_prefix.nodes.tsv`: flattened node metadata for trace column mapping
307/// - `out_prefix.summary.tsv`: aggregate totals and averages
308///
309/// The returned summary includes key scalar metrics and the concrete output
310/// paths.
311pub fn run_ac_log_loss_mixture_bytes(
312    data: &[u8],
313    spec: &MixtureSpec,
314    out_prefix: impl AsRef<Path>,
315) -> Result<AcLogLossRunSummary> {
316    spec.validate().map_err(anyhow::Error::msg)?;
317
318    let out_prefix = out_prefix.as_ref();
319    if let Some(parent) = out_prefix.parent()
320        && !parent.as_os_str().is_empty()
321    {
322        std::fs::create_dir_all(parent)
323            .with_context(|| format!("failed to create output directory {}", parent.display()))?;
324    }
325
326    let trace_path = out_prefix.with_extension("trace.tsv");
327    let nodes_path = out_prefix.with_extension("nodes.tsv");
328    let summary_path = out_prefix.with_extension("summary.tsv");
329    let schema = flatten_mixture_spec(spec);
330    write_nodes_tsv(&nodes_path, &schema)?;
331
332    let threads = parse_diagnostic_threads_from_env()?;
333    let pool = if threads > 1 {
334        Some(
335            rayon::ThreadPoolBuilder::new()
336                .num_threads(threads)
337                .build()
338                .context("failed to build dedicated AC diagnostic Rayon pool")?,
339        )
340    } else {
341        None
342    };
343
344    let mut predictor = DiagnosticRatePredictor::from_rate_backend(
345        RateBackend::Mixture {
346            spec: Arc::new(spec.clone()),
347        },
348        -1,
349    )?;
350    predictor.begin_stream(data.len())?;
351
352    let trace_file = File::create(&trace_path)
353        .with_context(|| format!("failed to create trace TSV {}", trace_path.display()))?;
354    let mut trace_writer = BufWriter::new(trace_file);
355
356    let mut trace_header = vec![
357        "t".to_string(),
358        "byte_u8".to_string(),
359        "byte_hex".to_string(),
360        "mix_prob".to_string(),
361        "mix_bits".to_string(),
362        "root_weight_entropy_bits".to_string(),
363        "root_top1_id".to_string(),
364        "root_top1_weight".to_string(),
365        "root_top2_id".to_string(),
366        "root_top2_weight".to_string(),
367        "root_top12_margin".to_string(),
368        "root_top1_switched".to_string(),
369        "oracle_best_id".to_string(),
370        "oracle_best_bits".to_string(),
371        "oracle_regret_bits".to_string(),
372        "oracle_best_switched".to_string(),
373    ];
374    for &node_id in &schema.non_root_ids {
375        trace_header.push(format!("n{node_id}__prob"));
376        trace_header.push(format!("n{node_id}__bits"));
377        trace_header.push(format!("n{node_id}__local_weight"));
378        trace_header.push(format!("n{node_id}__effective_weight"));
379    }
380    writeln!(trace_writer, "{}", trace_header.join("\t"))?;
381
382    let mut row_values = Vec::<AcLogLossNodeValue>::with_capacity(schema.non_root_ids.len());
383    let mut node_accum = vec![NodeSummaryAccum::default(); schema.non_root_ids.len()];
384    let mut mix_total_bits = 0.0;
385    let mut oracle_total_bits = 0.0;
386    let mut root_weight_entropy_sum = 0.0;
387    let mut root_top12_margin_sum = 0.0;
388    let mut root_top1_switch_count = 0u64;
389    let mut oracle_switch_count = 0u64;
390    let mut prev_root_top1_id = None;
391    let mut prev_oracle_id = None;
392
393    let mut counter = CountingWriter::default();
394    {
395        let mut encoder = crate::coders::ArithmeticEncoder::new(&mut counter);
396        for (t, &byte) in data.iter().enumerate() {
397            let root_snapshot =
398                predictor.diagnostic_root_snapshot(byte, pool.as_ref(), &mut row_values)?;
399            if row_values.len() != schema.non_root_ids.len() {
400                bail!(
401                    "diagnostic row width mismatch: got {}, expected {}",
402                    row_values.len(),
403                    schema.non_root_ids.len()
404                );
405            }
406
407            let mix_bits = bits_from_prob(root_snapshot.mix_prob);
408            mix_total_bits += mix_bits;
409            root_weight_entropy_sum += root_snapshot.root_weight_entropy_bits;
410            let root_top12_margin = root_snapshot.root_top1_weight - root_snapshot.root_top2_weight;
411            root_top12_margin_sum += root_top12_margin;
412
413            let root_top1_id = root_snapshot
414                .root_top1_child_index
415                .and_then(|idx| schema.root_child_ids.get(idx).copied());
416            let root_top2_id = root_snapshot
417                .root_top2_child_index
418                .and_then(|idx| schema.root_child_ids.get(idx).copied());
419            let root_top1_switched = prev_root_top1_id
420                .zip(root_top1_id)
421                .map(|(prev, curr)| prev != curr)
422                .unwrap_or(false);
423            if root_top1_switched {
424                root_top1_switch_count = root_top1_switch_count.saturating_add(1);
425            }
426            prev_root_top1_id = root_top1_id;
427
428            let mut oracle_best_row_index = 0usize;
429            let mut oracle_best_id = schema.non_root_ids[0];
430            let mut oracle_best_bits = bits_from_prob(row_values[0].prob);
431            for (index, (&node_id, value)) in schema
432                .non_root_ids
433                .iter()
434                .zip(row_values.iter())
435                .enumerate()
436            {
437                let node_bits = bits_from_prob(value.prob);
438                if node_bits < oracle_best_bits {
439                    oracle_best_bits = node_bits;
440                    oracle_best_id = node_id;
441                    oracle_best_row_index = index;
442                }
443                let acc = &mut node_accum[index];
444                acc.total_bits += node_bits;
445                acc.total_local_weight += value.local_weight;
446                acc.total_effective_weight += value.effective_weight;
447            }
448            node_accum[oracle_best_row_index].oracle_win_count = node_accum[oracle_best_row_index]
449                .oracle_win_count
450                .saturating_add(1);
451
452            let oracle_best_switched = prev_oracle_id
453                .zip(Some(oracle_best_id))
454                .map(|(prev, curr)| prev != curr)
455                .unwrap_or(false);
456            if oracle_best_switched {
457                oracle_switch_count = oracle_switch_count.saturating_add(1);
458            }
459            prev_oracle_id = Some(oracle_best_id);
460            oracle_total_bits += oracle_best_bits;
461
462            let mut row = Vec::with_capacity(trace_header.len());
463            row.push(t.to_string());
464            row.push(byte.to_string());
465            row.push(format!("{byte:02X}"));
466            row.push(format_f64(root_snapshot.mix_prob));
467            row.push(format_f64(mix_bits));
468            row.push(format_f64(root_snapshot.root_weight_entropy_bits));
469            row.push(format_optional_usize(root_top1_id));
470            row.push(format_f64(root_snapshot.root_top1_weight));
471            row.push(format_optional_usize(root_top2_id));
472            row.push(format_f64(root_snapshot.root_top2_weight));
473            row.push(format_f64(root_top12_margin));
474            row.push(bool_flag(root_top1_switched).to_string());
475            row.push(oracle_best_id.to_string());
476            row.push(format_f64(oracle_best_bits));
477            row.push(format_f64(mix_bits - oracle_best_bits));
478            row.push(bool_flag(oracle_best_switched).to_string());
479            for value in &row_values {
480                row.push(format_f64(value.prob));
481                row.push(format_f64(bits_from_prob(value.prob)));
482                row.push(format_f64(value.local_weight));
483                row.push(format_f64(value.effective_weight));
484            }
485            writeln!(trace_writer, "{}", row.join("\t"))?;
486
487            predictor.encode_symbol_ac_step(byte, &mut encoder)?;
488        }
489        let _ = encoder.finish()?;
490    }
491    predictor.finish_stream()?;
492    trace_writer.flush()?;
493
494    let positions = data.len();
495    let ac_payload_bits_raw = counter.bytes_written.saturating_mul(8);
496    let oracle_regret_bits = mix_total_bits - oracle_total_bits;
497    let root_weight_entropy_bits_avg = if positions > 0 {
498        root_weight_entropy_sum / (positions as f64)
499    } else {
500        0.0
501    };
502    let root_top12_margin_avg = if positions > 0 {
503        root_top12_margin_sum / (positions as f64)
504    } else {
505        0.0
506    };
507    let coder_overhead_bits = (ac_payload_bits_raw as f64) - mix_total_bits;
508
509    let summary_file = File::create(&summary_path)
510        .with_context(|| format!("failed to create summary TSV {}", summary_path.display()))?;
511    let mut summary_writer = BufWriter::new(summary_file);
512    let mut summary_header = vec![
513        "positions".to_string(),
514        "input_bytes".to_string(),
515        "mix_total_bits".to_string(),
516        "oracle_total_bits".to_string(),
517        "oracle_regret_bits".to_string(),
518        "root_top1_switch_count".to_string(),
519        "oracle_switch_count".to_string(),
520        "root_weight_entropy_bits_avg".to_string(),
521        "root_top12_margin_avg".to_string(),
522        "ac_payload_bits_raw".to_string(),
523        "coder_overhead_bits".to_string(),
524    ];
525    for &node_id in &schema.non_root_ids {
526        summary_header.push(format!("n{node_id}__total_bits"));
527        summary_header.push(format!("n{node_id}__regret_bits"));
528        summary_header.push(format!("n{node_id}__oracle_win_count"));
529        summary_header.push(format!("n{node_id}__avg_local_weight"));
530        summary_header.push(format!("n{node_id}__avg_effective_weight"));
531    }
532    writeln!(summary_writer, "{}", summary_header.join("\t"))?;
533
534    let mut summary_row = vec![
535        positions.to_string(),
536        data.len().to_string(),
537        format_f64(mix_total_bits),
538        format_f64(oracle_total_bits),
539        format_f64(oracle_regret_bits),
540        root_top1_switch_count.to_string(),
541        oracle_switch_count.to_string(),
542        format_f64(root_weight_entropy_bits_avg),
543        format_f64(root_top12_margin_avg),
544        ac_payload_bits_raw.to_string(),
545        format_f64(coder_overhead_bits),
546    ];
547    for acc in &node_accum {
548        summary_row.push(format_f64(acc.total_bits));
549        summary_row.push(format_f64(mix_total_bits - acc.total_bits));
550        summary_row.push(acc.oracle_win_count.to_string());
551        summary_row.push(format_f64(if positions > 0 {
552            acc.total_local_weight / (positions as f64)
553        } else {
554            0.0
555        }));
556        summary_row.push(format_f64(if positions > 0 {
557            acc.total_effective_weight / (positions as f64)
558        } else {
559            0.0
560        }));
561    }
562    writeln!(summary_writer, "{}", summary_row.join("\t"))?;
563    summary_writer.flush()?;
564
565    Ok(AcLogLossRunSummary {
566        trace_path,
567        nodes_path,
568        summary_path,
569        positions,
570        mix_total_bits,
571        oracle_total_bits,
572        ac_payload_bits_raw,
573    })
574}
575
#[cfg(test)]
mod tests {
    use super::*;

    /// Expert with an explicit `name`, a neutral prior, and `max_order = -1`.
    fn named_expert(name: &str, backend: RateBackend) -> MixtureExpertSpec {
        MixtureExpertSpec {
            name: Some(name.to_string()),
            log_prior: 0.0,
            max_order: -1,
            backend,
        }
    }

    /// Switching root over `ctw` and a nested Bayes mixture of `fac` + `ppmd`.
    ///
    /// Flattened, this yields five nodes: root(0), ctw(1), nested(2), fac(3), ppmd(4).
    fn test_nested_spec() -> MixtureSpec {
        let nested_mixture = MixtureSpec::new(
            MixtureKind::Bayes,
            vec![
                named_expert(
                    "fac",
                    RateBackend::FacCtw {
                        base_depth: 5,
                        num_percept_bits: 8,
                        encoding_bits: 8,
                    },
                ),
                named_expert(
                    "ppmd",
                    RateBackend::Ppmd {
                        order: 4,
                        memory_mb: 8,
                    },
                ),
            ],
        );
        let nested = MixtureExpertSpec {
            log_prior: -0.1,
            ..named_expert(
                "nested",
                RateBackend::Mixture {
                    spec: Arc::new(nested_mixture),
                },
            )
        };
        MixtureSpec::new(
            MixtureKind::Switching,
            vec![named_expert("ctw", RateBackend::Ctw { depth: 6 }), nested],
        )
        .with_alpha(0.2)
    }

    #[test]
    fn flatten_schema_includes_submixtures_and_descendants_in_preorder() {
        let schema = flatten_mixture_spec(&test_nested_spec());
        let nodes = &schema.nodes;
        assert_eq!(nodes.len(), 5);
        assert_eq!(nodes[0].display_name, "root");
        assert_eq!(schema.root_child_ids, vec![1, 2]);
        assert_eq!(schema.non_root_ids, vec![1, 2, 3, 4]);
        assert_eq!(nodes[2].display_name, "nested");
        assert!(nodes[2].is_mixture);
        // Both leaves of the nested mixture hang off node 2.
        for child in &nodes[3..] {
            assert_eq!(child.parent_id, Some(2));
        }
    }

    #[test]
    fn diagnostic_snapshot_matches_root_pdf_and_oracle_minimum() {
        let backend = RateBackend::Mixture {
            spec: Arc::new(test_nested_spec()),
        };
        let mut predictor =
            DiagnosticRatePredictor::from_rate_backend(backend, -1).expect("predictor");
        let payload = b"nested diagnostic payload";
        predictor.begin_stream(payload.len()).expect("begin stream");
        let mut rows = Vec::new();

        for &symbol in payload.iter() {
            let snapshot = predictor
                .diagnostic_root_snapshot(symbol, None, &mut rows)
                .expect("root snapshot");
            let root_prob = predictor.pdf_next().expect("root pdf")[symbol as usize];
            assert!(
                (snapshot.mix_prob - root_prob).abs() < 1e-8,
                "snapshot={} root_pdf={}",
                snapshot.mix_prob,
                root_prob
            );

            // Every node's bits must equal -log2(prob); the oracle is their minimum.
            let oracle_bits = rows
                .iter()
                .map(|row| {
                    let bits = bits_from_prob(row.prob);
                    assert!(
                        (bits + row.prob.log2()).abs() < 1e-8,
                        "bits={} prob={}",
                        bits,
                        row.prob
                    );
                    bits
                })
                .fold(f64::INFINITY, f64::min);
            let mix_bits = bits_from_prob(snapshot.mix_prob);
            assert!(
                oracle_bits <= mix_bits + 1e-8,
                "oracle_bits={} mix_bits={}",
                oracle_bits,
                mix_bits
            );

            predictor.update(symbol).expect("update");
        }
        predictor.finish_stream().expect("finish stream");
    }

    #[test]
    fn diagnostic_ac_payload_matches_raw_ac_compression_size() {
        let spec = test_nested_spec();
        let data = b"payload bits raw diagnostic parity";
        let nanos = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .expect("clock")
            .as_nanos();
        let prefix = std::env::temp_dir().join(format!(
            "infotheory_ac_diag_{}_{}",
            std::process::id(),
            nanos
        ));

        let summary =
            run_ac_log_loss_mixture_bytes(data, &spec, &prefix).expect("diagnostic run succeeds");
        let encoded = crate::compression::compress_rate_bytes(
            data,
            &RateBackend::Mixture {
                spec: Arc::new(spec),
            },
            -1,
            crate::coders::CoderType::AC,
            crate::compression::FramingMode::Raw,
        )
        .expect("raw ac compression");
        assert_eq!(summary.ac_payload_bits_raw, (encoded.len() as u64) * 8);

        for artifact in [
            &summary.trace_path,
            &summary.nodes_path,
            &summary.summary_path,
        ] {
            let _ = std::fs::remove_file(artifact);
        }
    }
}