Skip to main content

infotheory/aixi/
vm_nyx.rs

1//! High-performance VM-backed AIXI environment using nyx-lite (Firecracker).
2//!
3//! This module provides a VM environment implementation built on top of nyx-lite,
4//! enabling high-frequency snapshot-based resets for fast experimentation (hardware and
5//! guest behavior dependent).
6//!
7//! ## Architecture
8//!
9//! The environment uses Firecracker's KVM-based microVM with nyx-lite's incremental
10//! snapshot and reset capabilities. Communication with the guest occurs via:
11//!
12//! 1. **Shared Memory**: Zero-copy data transfer between host and guest
13//! 2. **Hypercalls**: Control plane communication (snapshot, done, etc.)
14//! 3. **Serial PTY**: Optional console I/O for simpler protocols
15//!
16//! ## Design Principles
17//!
18//! - **Universal**: Not biased towards any specific use case (fuzzing, etc.)
19//! - **High Performance**: Leverages incremental snapshots and dirty page tracking
20//! - **Configurable**: Pluggable reward policies, action sources, observation modes
21//! - **Information-Theoretic**: Built-in support for entropy-based metrics
22
23use crate::aixi::common::{Action, PerceptVal, RandomGenerator, Reward};
24use crate::aixi::environment::Environment;
25#[cfg(feature = "backend-rwkv")]
26use crate::coders::softmax_pdf_inplace;
27#[cfg(feature = "backend-mamba")]
28use crate::mambazip;
29#[cfg(feature = "backend-mamba")]
30use crate::mambazip::Compressor as MambaCompressor;
31use crate::mixture::OnlineBytePredictor;
32use crate::rosaplus::RosaPlus;
33#[cfg(feature = "backend-rwkv")]
34use crate::rwkvzip::Compressor;
35use crate::zpaq_rate::ZpaqRateModel;
36use crate::{
37    RateBackend, cross_entropy_rate_backend, entropy_rate_backend, marginal_entropy_bytes,
38};
39use serde_json::Value;
40use std::borrow::Cow;
41use std::fs::OpenOptions;
42use std::io::Write;
43use std::path::Path;
44use std::sync::Arc;
45use std::time::{Duration, Instant};
46
47// Re-export nyx-lite types for external use
48pub use nyx_lite::mem::SharedMemoryRegion;
49pub use nyx_lite::snapshot::NyxSnapshot;
50pub use nyx_lite::{ExitReason, NyxVM, SharedMemoryPolicy};
51
52// ============================================================================
53// Encoding Types
54// ============================================================================
55
/// Payload encoding for wire protocol.
///
/// Selects how [`PayloadEncoding::decode`] and [`PayloadEncoding::encode`]
/// translate between wire strings and raw payload bytes.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PayloadEncoding {
    /// Treat payloads as UTF-8/text bytes.
    ///
    /// Decoding copies the string's bytes verbatim; encoding is lossy for byte
    /// sequences that are not valid UTF-8.
    Utf8,
    /// Treat payloads as hexadecimal text (two hex digits per byte).
    Hex,
}
64
65impl PayloadEncoding {
66    /// Parse a payload encoding label.
67    ///
68    /// Accepted values are `utf8`, `text`, and `hex`.
69    #[allow(clippy::should_implement_trait)]
70    pub fn from_str(s: &str) -> Option<Self> {
71        Self::parse(s)
72    }
73
74    /// Parse a payload encoding label.
75    ///
76    /// Accepted values are `utf8`, `text`, and `hex`.
77    pub fn parse(s: &str) -> Option<Self> {
78        match s {
79            "utf8" | "text" => Some(Self::Utf8),
80            "hex" => Some(Self::Hex),
81            _ => None,
82        }
83    }
84
85    /// Decode a wire payload string into raw bytes using this encoding.
86    pub fn decode(self, s: &str) -> anyhow::Result<Vec<u8>> {
87        match self {
88            Self::Utf8 => Ok(s.as_bytes().to_vec()),
89            Self::Hex => hex_decode(s),
90        }
91    }
92
93    /// Encode raw bytes for transport over the configured wire protocol.
94    pub fn encode(self, bytes: &[u8]) -> String {
95        match self {
96            Self::Utf8 => String::from_utf8_lossy(bytes).to_string(),
97            Self::Hex => hex_encode(bytes),
98        }
99    }
100}
101
102impl std::str::FromStr for PayloadEncoding {
103    type Err = &'static str;
104
105    fn from_str(s: &str) -> Result<Self, Self::Err> {
106        Self::parse(s).ok_or("unknown payload encoding")
107    }
108}
109
110fn hex_decode(s: &str) -> anyhow::Result<Vec<u8>> {
111    let mut out = Vec::with_capacity(s.len() / 2);
112    let mut buf = 0u8;
113    let mut high = true;
114    for c in s.bytes() {
115        let v = match c {
116            b'0'..=b'9' => c - b'0',
117            b'a'..=b'f' => c - b'a' + 10,
118            b'A'..=b'F' => c - b'A' + 10,
119            b' ' | b'\n' | b'\r' | b'\t' => continue,
120            _ => return Err(anyhow::anyhow!("invalid hex byte: {}", c as char)),
121        };
122        if high {
123            buf = v << 4;
124            high = false;
125        } else {
126            buf |= v;
127            out.push(buf);
128            high = true;
129        }
130    }
131    if !high {
132        return Err(anyhow::anyhow!("hex string has odd length"));
133    }
134    Ok(out)
135}
136
/// Anchor `path` at `base` unless it is already absolute.
///
/// Absolute inputs are returned unchanged; relative ones are joined onto
/// `base` and converted back to a `String` (lossily, if the joined path is not
/// valid UTF-8).
fn resolve_relative_path(base: &Path, path: &str) -> String {
    let candidate = Path::new(path);
    if candidate.is_relative() {
        return base.join(candidate).to_string_lossy().into_owned();
    }
    path.to_owned()
}
145
146fn rewrite_firecracker_config_paths(config_path: &str, raw_json: &str) -> anyhow::Result<String> {
147    let base_dir = Path::new(config_path)
148        .parent()
149        .unwrap_or_else(|| Path::new("."));
150    let mut v: Value = serde_json::from_str(raw_json)?;
151
152    if let Some(boot) = v.get_mut("boot-source") {
153        if let Some(path_val) = boot.get_mut("kernel_image_path")
154            && let Some(path_str) = path_val.as_str()
155        {
156            let resolved = resolve_relative_path(base_dir, path_str);
157            *path_val = Value::String(resolved);
158        }
159        if let Some(path_val) = boot.get_mut("initrd_path")
160            && let Some(path_str) = path_val.as_str()
161        {
162            let resolved = resolve_relative_path(base_dir, path_str);
163            *path_val = Value::String(resolved);
164        }
165    }
166
167    if let Some(drives) = v.get_mut("drives").and_then(|d| d.as_array_mut()) {
168        for drive in drives {
169            if let Some(path_val) = drive.get_mut("path_on_host")
170                && let Some(path_str) = path_val.as_str()
171            {
172                let resolved = resolve_relative_path(base_dir, path_str);
173                *path_val = Value::String(resolved);
174            }
175        }
176    }
177
178    Ok(serde_json::to_string(&v)?)
179}
180
/// Render `bytes` as lowercase hexadecimal, two digits per byte, with no
/// separators or prefix.
fn hex_encode(bytes: &[u8]) -> String {
    bytes
        .iter()
        .fold(String::with_capacity(bytes.len() * 2), |mut acc, &byte| {
            acc.push(hex_digit(byte >> 4));
            acc.push(hex_digit(byte & 0x0F));
            acc
        })
}

/// Map a nibble value to its lowercase hex digit.
///
/// Callers are expected to pass `0..=15`; values below 10 map to `'0'..='9'`
/// and everything else is offset from `'a'`.
fn hex_digit(v: u8) -> char {
    if v <= 9 {
        char::from(b'0' + v)
    } else {
        char::from(b'a' + (v - 10))
    }
}
196
197// ============================================================================
198// Guest Communication Protocol
199// ============================================================================
200
/// Hypercall identifiers (must match guest implementation).
/// These are exported for use by custom guest programs.
///
/// Each identifier is the 8-byte ASCII tag read as a little-endian `u64`;
/// this one is the `"execdone"` tag.
#[allow(dead_code)]
pub const HYPERCALL_EXECDONE: u64 = u64::from_le_bytes(*b"execdone");
/// Guest requested host-side snapshot operation (`"snapshot"` tag).
#[allow(dead_code)]
pub const HYPERCALL_SNAPSHOT: u64 = u64::from_le_bytes(*b"snapshot");
/// Guest announced nyx-lite protocol/version handshake (`"nyx-lite"` tag).
#[allow(dead_code)]
pub const HYPERCALL_NYX_LITE: u64 = u64::from_le_bytes(*b"nyx-lite");
/// Guest requested shared memory initialization/refresh (`"sharemem"` tag).
#[allow(dead_code)]
pub const HYPERCALL_SHAREMEM: u64 = u64::from_le_bytes(*b"sharemem");
/// Guest emitted a debug-print hypercall payload (`"dbgprint"` tag).
#[allow(dead_code)]
pub const HYPERCALL_DBGPRINT: u64 = u64::from_le_bytes(*b"dbgprint");
217
// Byte offsets used when addressing the shared region.
// NOTE(review): the 8-byte spacing suggests two 64-bit length fields followed
// by the payload area — confirm against the code that reads/writes them.

/// Offset of the action-length field.
const SHARED_ACTION_LEN_OFFSET: u64 = 0;
/// Offset of the response-length field.
const SHARED_RESP_LEN_OFFSET: u64 = 8;
/// Offset at which payload bytes start.
const SHARED_PAYLOAD_OFFSET: u64 = 16;
221
/// Protocol configuration for structured communication.
///
/// Holds the textual markers and payload encoding of the wire protocol. The
/// `Default` implementation uses `"ACT "`, `"\n"`, `"OBS "`, `"REW "`,
/// `"DONE "`, `"DATA "` and hex payloads.
#[derive(Clone, Debug)]
pub struct NyxProtocolConfig {
    /// Prefix for action messages.
    pub action_prefix: String,
    /// Suffix for action messages.
    pub action_suffix: String,
    /// Prefix for observation responses.
    pub obs_prefix: String,
    /// Prefix for reward responses.
    pub rew_prefix: String,
    /// Prefix for done indicator.
    pub done_prefix: String,
    /// Prefix for data payloads.
    pub data_prefix: String,
    /// Wire encoding for payloads (see [`PayloadEncoding`]).
    pub wire_encoding: PayloadEncoding,
}
240
241impl Default for NyxProtocolConfig {
242    fn default() -> Self {
243        Self {
244            action_prefix: "ACT ".to_string(),
245            action_suffix: "\n".to_string(),
246            obs_prefix: "OBS ".to_string(),
247            rew_prefix: "REW ".to_string(),
248            done_prefix: "DONE ".to_string(),
249            data_prefix: "DATA ".to_string(),
250            wire_encoding: PayloadEncoding::Hex,
251        }
252    }
253}
254
255// ============================================================================
256// Action Configuration
257// ============================================================================
258
/// A single action specification.
///
/// One entry of the fixed action table carried by
/// [`NyxActionSource::Literal`].
#[derive(Clone, Debug)]
pub struct NyxActionSpec {
    /// Optional human-readable name.
    pub name: Option<String>,
    /// Raw payload bytes to send.
    pub payload: Vec<u8>,
}
267
/// Fuzzing mutator types.
///
/// The set of mutators available to a fuzzing run is listed in
/// [`NyxFuzzConfig::mutators`]. The mutation logic itself is implemented
/// elsewhere in this module; the variant docs describe the intended effect.
#[derive(Clone, Debug)]
pub enum FuzzMutator {
    /// Flip one random bit.
    FlipBit,
    /// Flip one full byte.
    FlipByte,
    /// Insert a random byte at a random position.
    InsertByte,
    /// Delete one random byte.
    DeleteByte,
    /// Splice bytes from an existing seed input.
    SpliceSeed,
    /// Replace the working input with a seed input.
    ResetSeed,
    /// Apply a short sequence of random mutations.
    Havoc,
}
286
/// Fuzzing configuration for action generation.
///
/// Carried by [`NyxActionSource::Fuzz`].
#[derive(Clone, Debug)]
pub struct NyxFuzzConfig {
    /// Corpus used for seed/reset/splice operations.
    pub seeds: Vec<Vec<u8>>,
    /// Mutator set available for action generation.
    pub mutators: Vec<FuzzMutator>,
    /// Minimum generated action length.
    // NOTE(review): presumably in bytes and expected to be <= `max_len` — confirm.
    pub min_len: usize,
    /// Maximum generated action length.
    pub max_len: usize,
    /// Optional dictionary tokens for insertion/splicing.
    pub dictionary: Vec<Vec<u8>>,
    /// Deterministic RNG seed for mutation sampling.
    pub rng_seed: u64,
}
303
/// Source of actions for the environment.
///
/// Stored in [`NyxVmConfig::action_source`].
#[derive(Clone, Debug)]
pub enum NyxActionSource {
    /// Fixed set of action payloads, one [`NyxActionSpec`] per action.
    Literal(Vec<NyxActionSpec>),
    /// Mutation-based action generation driven by a [`NyxFuzzConfig`].
    Fuzz(NyxFuzzConfig),
}
312
313// ============================================================================
314// Observation Configuration
315// ============================================================================
316
/// How observations are derived from guest output.
///
/// Stored in [`NyxVmConfig::observation_policy`]; the default configuration
/// selects [`NyxObservationPolicy::SharedMemory`].
#[derive(Clone, Copy, Debug)]
pub enum NyxObservationPolicy {
    /// Parse structured OBS/REW/DONE messages from guest.
    FromGuest,
    /// Hash raw output to derive observation.
    OutputHash,
    /// Use raw output bytes as observation stream.
    RawOutput,
    /// Use shared memory contents as observation.
    SharedMemory,
}
329
/// Stream normalization mode.
///
/// Stored in [`NyxVmConfig::observation_stream_mode`].
// NOTE(review): "short"/"long" are presumably relative to
// `observation_stream_len`, with `observation_pad_byte` as filler — confirm
// against the normalization code.
#[derive(Clone, Copy, Debug)]
pub enum NyxObservationStreamMode {
    /// Pad short streams, truncate long ones.
    PadTruncate,
    /// Only pad short streams.
    Pad,
    /// Only truncate long streams.
    Truncate,
}
340
341// ============================================================================
342// Reward Configuration
343// ============================================================================
344
/// How rewards are computed.
///
/// Only `Clone` is derived; `Debug` is implemented by hand because the
/// `Custom` variant holds a closure.
#[derive(Clone)]
pub enum NyxRewardPolicy {
    /// Parse reward from guest response.
    FromGuest,
    /// Pattern matching on output.
    Pattern {
        /// Substring/pattern tested against guest output.
        pattern: String,
        /// Reward returned when the pattern does not match.
        base_reward: i64,
        /// Additional reward added when the pattern matches.
        bonus_reward: i64,
    },
    /// Custom reward function (callback-based).
    ///
    /// The callback receives the [`NyxStepResult`] of a step and returns its
    /// reward. It is shared through an `Arc` and must be `Send + Sync`.
    Custom(Arc<dyn Fn(&NyxStepResult) -> Reward + Send + Sync>),
}
362
/// Optional reward shaping (additive to base reward).
///
/// Stored in [`NyxVmConfig::reward_shaping`]; `None` there disables shaping.
#[derive(Clone, Debug)]
pub enum NyxRewardShaping {
    /// Entropy reduction vs baseline.
    EntropyReduction {
        /// Reference bytes used as baseline data distribution.
        baseline_bytes: Vec<u8>,
        /// Max order passed to entropy estimators.
        max_order: i64,
        /// Scaling factor applied to the shaping term.
        scale: f64,
        /// Optional additive bonus when guest crashes.
        crash_bonus: Option<i64>,
        /// Optional additive bonus when guest times out.
        timeout_bonus: Option<i64>,
    },
    /// Entropy of trace data (online learning).
    TraceEntropy {
        /// Max order passed to trace entropy estimation.
        max_order: i64,
        /// Scaling factor applied to the shaping term.
        scale: f64,
        /// If true, normalize by trace length.
        normalize: bool,
    },
}
389
390impl std::fmt::Debug for NyxRewardPolicy {
391    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
392        match self {
393            Self::FromGuest => write!(f, "FromGuest"),
394            Self::Pattern {
395                pattern,
396                base_reward,
397                bonus_reward,
398            } => f
399                .debug_struct("Pattern")
400                .field("pattern", pattern)
401                .field("base_reward", base_reward)
402                .field("bonus_reward", bonus_reward)
403                .finish(),
404            Self::Custom(_) => write!(f, "Custom(<fn>)"),
405        }
406    }
407}
408
409// ============================================================================
410// Action Filtering
411// ============================================================================
412
/// Information-theoretic action filtering.
///
/// Stored in [`NyxVmConfig::action_filter`].
// NOTE(review): an unset (`None`) threshold presumably disables that check —
// confirm against the filter implementation.
#[derive(Clone, Debug)]
pub struct NyxActionFilter {
    /// Minimum entropy threshold.
    pub min_entropy: Option<f64>,
    /// Maximum entropy threshold.
    pub max_entropy: Option<f64>,
    /// Minimum intrinsic dependence.
    pub min_intrinsic_dependence: Option<f64>,
    /// Minimum novelty (cross-entropy vs prior).
    pub min_novelty: Option<f64>,
    /// Prior corpus for novelty computation.
    pub novelty_prior: Option<Vec<u8>>,
    /// Max order for entropy estimation.
    pub max_order: i64,
    /// Reward to assign when action is rejected.
    pub reject_reward: Option<i64>,
}
431
432// ============================================================================
433// Trace Configuration
434// ============================================================================
435
/// Configuration for trace collection and analysis.
///
/// Stored in [`NyxVmConfig::trace`]; `None` there means no trace
/// configuration is present.
#[derive(Clone, Debug)]
pub struct NyxTraceConfig {
    /// Shared memory region name for trace data.
    pub shared_region_name: Option<String>,
    /// Maximum bytes to collect per step.
    pub max_bytes: usize,
    /// Reset trace model on episode boundary.
    pub reset_on_episode: bool,
}
446
447// ============================================================================
448// Main Configuration
449// ============================================================================
450
/// Complete configuration for the nyx-lite VM environment.
///
/// "Default" values quoted on the fields below are the ones produced by this
/// type's `Default` implementation.
#[derive(Clone)]
pub struct NyxVmConfig {
    /// Path to Firecracker JSON config. Default: empty string.
    pub firecracker_config: String,
    /// Instance ID for the VM. Default: `"aixi-nyx"`.
    pub instance_id: String,

    // Shared memory configuration
    /// Name of the shared memory region for communication. Default: `"shared"`.
    pub shared_region_name: String,
    /// Size of the shared memory region. Default: `4096`.
    pub shared_region_size: usize,
    /// Shared memory policy (snapshot vs preserve). Default: `Snapshot`.
    pub shared_memory_policy: SharedMemoryPolicy,

    // Timing configuration
    /// Timeout for each step. Default: 100 ms.
    pub step_timeout: Duration,
    /// Timeout for initial boot. Default: 30 s.
    pub boot_timeout: Duration,

    // Episode configuration
    /// Number of steps per episode. Default: `100`.
    pub episode_steps: usize,
    /// Cost subtracted from reward each step. Default: `0`.
    pub step_cost: i64,

    // Observation configuration
    /// Observation derivation policy. Default: `SharedMemory`.
    pub observation_policy: NyxObservationPolicy,
    /// Bits per observation symbol. Default: `8`.
    pub observation_bits: usize,
    /// Number of observation symbols per action. Default: `64`.
    pub observation_stream_len: usize,
    /// Stream normalization mode. Default: `PadTruncate`.
    pub observation_stream_mode: NyxObservationStreamMode,
    /// Padding byte for short streams. Default: `0`.
    pub observation_pad_byte: u8,

    // Reward configuration
    /// Bits for reward encoding. Default: `8`.
    pub reward_bits: usize,
    /// Reward computation policy. Default: `FromGuest`.
    pub reward_policy: NyxRewardPolicy,
    /// Optional reward shaping (additive; non-canonical). Default: `None`.
    pub reward_shaping: Option<NyxRewardShaping>,

    // Action configuration
    /// Source of actions. Default: an empty `Literal` action list.
    pub action_source: NyxActionSource,
    /// Optional action filter. Default: `None`.
    pub action_filter: Option<NyxActionFilter>,

    // Protocol configuration
    /// Wire protocol for structured communication.
    /// Default: `NyxProtocolConfig::default()`.
    pub protocol: NyxProtocolConfig,

    // Statistics backend
    /// Backend for entropy estimation. Default: `RateBackend::default()`.
    pub stats_backend: RateBackend,

    // Trace configuration
    /// Optional trace collection. Default: `None`.
    pub trace: Option<NyxTraceConfig>,

    // Debug mode
    /// Enable verbose VM/protocol diagnostics. Default: `false`.
    pub debug_mode: bool,

    // Crash logging
    /// Path to log crashes/interesting behaviors (JSONL format).
    /// Default: `None`.
    pub crash_log: Option<String>,
}
525
526impl Default for NyxVmConfig {
527    fn default() -> Self {
528        Self {
529            firecracker_config: String::new(),
530            instance_id: "aixi-nyx".to_string(),
531            shared_region_name: "shared".to_string(),
532            shared_region_size: 4096,
533            shared_memory_policy: SharedMemoryPolicy::Snapshot,
534            step_timeout: Duration::from_millis(100),
535            boot_timeout: Duration::from_secs(30),
536            episode_steps: 100,
537            step_cost: 0,
538            observation_policy: NyxObservationPolicy::SharedMemory,
539            observation_bits: 8,
540            observation_stream_len: 64,
541            observation_stream_mode: NyxObservationStreamMode::PadTruncate,
542            observation_pad_byte: 0,
543            reward_bits: 8,
544            reward_policy: NyxRewardPolicy::FromGuest,
545            reward_shaping: None,
546            action_source: NyxActionSource::Literal(vec![]),
547            action_filter: None,
548            protocol: NyxProtocolConfig::default(),
549            stats_backend: RateBackend::default(),
550            trace: None,
551            debug_mode: false,
552            crash_log: None,
553        }
554    }
555}
556
557// ============================================================================
558// Step Result
559// ============================================================================
560
/// Result of a single environment step.
///
/// This is also the value handed to a [`NyxRewardPolicy::Custom`] callback.
#[derive(Clone, Debug)]
pub struct NyxStepResult {
    /// Exit reason from the VM.
    pub exit_reason: NyxExitKind,
    /// Raw output data from guest.
    pub output: Vec<u8>,
    /// Parsed observation (if any).
    pub parsed_obs: Option<u64>,
    /// Parsed reward (if any).
    pub parsed_rew: Option<i64>,
    /// Done flag.
    pub done: bool,
    /// Trace data (if collected).
    pub trace_data: Vec<u8>,
    /// Shared memory contents snapshot.
    pub shared_memory: Vec<u8>,
}
579
/// Simplified exit reason categories.
///
/// Produced from nyx-lite's [`ExitReason`] via `From`; exit reasons without a
/// dedicated variant here are folded into [`NyxExitKind::Other`] as text.
#[derive(Clone, Debug)]
pub enum NyxExitKind {
    /// Guest terminated normally with an application-defined code.
    ExecDone(u64),
    /// Step timed out before a terminal signal/response.
    Timeout,
    /// VM reported a shutdown event.
    Shutdown,
    /// Raw hypercall event with integer arguments.
    Hypercall {
        /// Hypercall identifier/magic value.
        code: u64,
        /// Hypercall argument 1.
        arg1: u64,
        /// Hypercall argument 2.
        arg2: u64,
        /// Hypercall argument 3.
        arg3: u64,
        /// Hypercall argument 4.
        arg4: u64,
    },
    /// Debug string emitted by guest/host bridge.
    DebugPrint(String),
    /// Breakpoint/trap-like stop event.
    Breakpoint,
    /// Uncategorized exit event represented as text.
    Other(String),
}
609
610impl From<ExitReason> for NyxExitKind {
611    fn from(reason: ExitReason) -> Self {
612        match reason {
613            ExitReason::ExecDone(code) => Self::ExecDone(code),
614            ExitReason::Timeout => Self::Timeout,
615            ExitReason::Shutdown => Self::Shutdown,
616            ExitReason::Hypercall(r8, r9, r10, r11, r12) => Self::Hypercall {
617                code: r8,
618                arg1: r9,
619                arg2: r10,
620                arg3: r11,
621                arg4: r12,
622            },
623            ExitReason::DebugPrint(s) => Self::DebugPrint(s),
624            ExitReason::Breakpoint => Self::Breakpoint,
625            ExitReason::RequestSnapshot => Self::Other("RequestSnapshot".to_string()),
626            ExitReason::SharedMem(name, _, _) => Self::Other(format!("SharedMem({})", name)),
627            ExitReason::SingleStep => Self::Other("SingleStep".to_string()),
628            ExitReason::Interrupted => Self::Other("Interrupted".to_string()),
629            ExitReason::HWBreakpoint(n) => Self::Other(format!("HWBreakpoint({})", n)),
630            ExitReason::BadMemoryAccess(_) => Self::Other("BadMemoryAccess".to_string()),
631        }
632    }
633}
634
635// ============================================================================
636// Trace Model
637// ============================================================================
638
/// Predictive model for trace-based reward computation.
///
/// One variant per family of `RateBackend`; each wraps the online state needed
/// to score trace bytes incrementally (see `TraceModel::update_and_score`).
// NOTE(review): `Compressor` is imported only under
// `#[cfg(feature = "backend-rwkv")]`, but the `Rwkv7` variant below is not
// gated the way `Mamba` is. Confirm this module is only built with that
// feature enabled, or gate the variant and its match arms consistently.
enum TraceModel {
    /// ROSA+ language model, plus the order needed to rebuild it on reset.
    Rosa {
        model: RosaPlus,
        max_order: i64,
    },
    /// Context-tree weighting over individual bits.
    Ctw {
        tree: crate::ctw::ContextTree,
    },
    /// Factorized CTW; `bits_per_symbol` is the number of bits fed per byte.
    FacCtw {
        tree: crate::ctw::FacContextTree,
        bits_per_symbol: usize,
    },
    /// Mamba compressor; `primed` records whether `pdf_buffer` already holds a
    /// prediction for the next byte.
    #[cfg(feature = "backend-mamba")]
    Mamba {
        compressor: MambaCompressor,
        primed: bool,
    },
    /// RWKV7 compressor; `primed` has the same meaning as for `Mamba`.
    Rwkv7 {
        compressor: Compressor,
        primed: bool,
    },
    /// ZPAQ-based rate model.
    Zpaq {
        model: ZpaqRateModel,
    },
    /// Generic predictor-backed model; `backend` is kept so the predictor can
    /// be rebuilt from scratch on reset.
    Mixture {
        backend: RateBackend,
        model: crate::mixture::RateBackendPredictor,
    },
}
669
impl TraceModel {
    /// Build a `Mixture` model around a generic `RateBackendPredictor`.
    ///
    /// # Panics
    ///
    /// Panics if the predictor's stream cannot be started.
    // NOTE(review): the `-1` and `2^-24` arguments are passed straight through
    // to `RateBackendPredictor::from_backend`; their meaning is defined there.
    fn predictor_backed(backend: RateBackend) -> Self {
        let mut model =
            crate::mixture::RateBackendPredictor::from_backend(backend.clone(), -1, 2f64.powi(-24));
        model
            .begin_stream(None)
            .unwrap_or_else(|e| panic!("predictor-backed stream init failed: {e}"));
        TraceModel::Mixture { backend, model }
    }

    /// Create a fresh trace model for `backend`.
    ///
    /// `max_order` is only used by the ROSA+ backend; every other backend
    /// takes its parameters from the `RateBackend` value itself.
    ///
    /// # Panics
    ///
    /// Panics if a Mamba/RWKV7 method string is rejected by its compressor, or
    /// if a predictor-backed stream cannot be started.
    fn new(backend: &RateBackend, max_order: i64) -> Self {
        match backend {
            RateBackend::RosaPlus => {
                let mut model = RosaPlus::new(max_order, false, 0, 42);
                model.build_lm_full_bytes_no_finalize_endpos();
                TraceModel::Rosa { model, max_order }
            }
            #[cfg(feature = "backend-mamba")]
            RateBackend::Mamba { model } => {
                let compressor = MambaCompressor::new_from_model(model.clone());
                TraceModel::Mamba {
                    compressor,
                    primed: false,
                }
            }
            #[cfg(feature = "backend-mamba")]
            RateBackend::MambaMethod { method } => {
                let compressor = MambaCompressor::new_from_method(method)
                    .unwrap_or_else(|e| panic!("invalid mamba method for vm trace model: {e}"));
                TraceModel::Mamba {
                    compressor,
                    primed: false,
                }
            }
            RateBackend::Rwkv7 { model } => {
                let compressor = Compressor::new_from_model(model.clone());
                TraceModel::Rwkv7 {
                    compressor,
                    primed: false,
                }
            }
            RateBackend::Rwkv7Method { method } => {
                let compressor = Compressor::new_from_method(method)
                    .unwrap_or_else(|e| panic!("invalid rwkv7 method for vm trace model: {e}"));
                TraceModel::Rwkv7 {
                    compressor,
                    primed: false,
                }
            }
            RateBackend::Zpaq { method } => TraceModel::Zpaq {
                model: ZpaqRateModel::new(method.clone(), 2f64.powi(-24)),
            },
            // Backends without a dedicated variant all go through the generic
            // predictor wrapper.
            RateBackend::Mixture { .. }
            | RateBackend::Particle { .. }
            | RateBackend::Match { .. }
            | RateBackend::SparseMatch { .. }
            | RateBackend::Ppmd { .. }
            | RateBackend::Sequitur { .. }
            | RateBackend::Calibrated { .. } => TraceModel::predictor_backed(backend.clone()),
            RateBackend::Ctw { depth } => TraceModel::Ctw {
                tree: crate::ctw::ContextTree::new(*depth),
            },
            RateBackend::FacCtw {
                base_depth,
                num_percept_bits: _,
                encoding_bits,
            } => {
                // Trace symbols are bytes, so at most 8 bits per symbol are fed.
                let bits_per_symbol = (*encoding_bits).clamp(1, 8);
                TraceModel::FacCtw {
                    tree: crate::ctw::FacContextTree::new(*base_depth, bits_per_symbol),
                    bits_per_symbol,
                }
            }
        }
    }

    /// Discard everything learned so far and return to the freshly-built state.
    ///
    /// # Panics
    ///
    /// Panics if a rebuilt predictor-backed stream cannot be started.
    fn reset(&mut self) {
        match self {
            TraceModel::Rosa { model, max_order } => {
                // Rebuilt from scratch with the same parameters as in `new`.
                let mut fresh = RosaPlus::new(*max_order, false, 0, 42);
                fresh.build_lm_full_bytes_no_finalize_endpos();
                *model = fresh;
            }
            TraceModel::Ctw { tree } => tree.clear(),
            TraceModel::FacCtw { tree, .. } => tree.clear(),
            #[cfg(feature = "backend-mamba")]
            TraceModel::Mamba { compressor, primed } => {
                compressor.state.reset();
                // Force the next `update_and_score` to re-prime `pdf_buffer`.
                *primed = false;
            }
            TraceModel::Rwkv7 { compressor, primed } => {
                compressor.state.reset();
                // Force the next `update_and_score` to re-prime `pdf_buffer`.
                *primed = false;
            }
            TraceModel::Zpaq { model } => {
                model.reset();
            }
            TraceModel::Mixture { backend, model } => {
                // Replaced by a new predictor built from the stored backend.
                *model = crate::mixture::RateBackendPredictor::from_backend(
                    backend.clone(),
                    -1,
                    2f64.powi(-24),
                );
                model
                    .begin_stream(None)
                    .unwrap_or_else(|e| panic!("mixture stream init failed: {e}"));
            }
        }
    }

    /// Update the model with new data and return the surprise (bits).
    ///
    /// Each byte is scored under the model's current prediction *before* the
    /// model is updated with it. Empty input returns `0.0` without touching
    /// the model.
    fn update_and_score(&mut self, data: &[u8]) -> f64 {
        if data.is_empty() {
            return 0.0;
        }
        match self {
            TraceModel::Rosa { model, .. } => {
                let mut bits = 0.0;
                for &b in data {
                    // Floor the probability so a zero estimate cannot yield
                    // an infinite score.
                    let p = model.prob_for_last(b as u32).max(1e-12);
                    bits -= p.log2();
                    model.train_byte(b);
                }
                bits
            }
            TraceModel::Ctw { tree } => {
                let log_before = tree.get_log_block_probability();
                for &b in data {
                    // Bits are fed most-significant first.
                    for i in (0..8).rev() {
                        tree.update(((b >> i) & 1) == 1);
                    }
                }
                let log_after = tree.get_log_block_probability();
                let log_delta = log_after - log_before;
                // NOTE(review): dividing by ln 2 assumes the block probability
                // is a natural log — confirm in `ContextTree`.
                -log_delta / std::f64::consts::LN_2
            }
            TraceModel::FacCtw {
                tree,
                bits_per_symbol,
            } => {
                let log_before = tree.get_log_block_probability();
                for &b in data {
                    // Only the low `bits_per_symbol` bits are fed, least
                    // significant first, with the bit position as second arg.
                    for i in 0..*bits_per_symbol {
                        tree.update(((b >> i) & 1) == 1, i);
                    }
                }
                let log_after = tree.get_log_block_probability();
                let log_delta = log_after - log_before;
                -log_delta / std::f64::consts::LN_2
            }
            #[cfg(feature = "backend-mamba")]
            TraceModel::Mamba { compressor, primed } => {
                if !*primed {
                    // Run one forward pass on token 0 so `pdf_buffer` holds a
                    // prediction for the first byte.
                    let bias = compressor.online_bias_snapshot();
                    let logits =
                        compressor
                            .model
                            .forward(&mut compressor.scratch, 0, &mut compressor.state);
                    mambazip::Compressor::logits_to_pdf(
                        logits,
                        bias.as_deref(),
                        &mut compressor.pdf_buffer,
                    );
                    *primed = true;
                }
                let mut bits = 0.0;
                for &b in data {
                    // Score under the current prediction, then advance the
                    // model and refresh `pdf_buffer` for the next byte.
                    let p = compressor.pdf_buffer[b as usize].max(1e-12);
                    bits -= p.log2();
                    let bias = compressor.online_bias_snapshot();
                    let logits = compressor.model.forward(
                        &mut compressor.scratch,
                        b as u32,
                        &mut compressor.state,
                    );
                    mambazip::Compressor::logits_to_pdf(
                        logits,
                        bias.as_deref(),
                        &mut compressor.pdf_buffer,
                    );
                }
                bits
            }
            TraceModel::Rwkv7 { compressor, primed } => {
                if !*primed {
                    // Run one forward pass on token 0 so `pdf_buffer` holds a
                    // prediction for the first byte.
                    let vocab_size = compressor.vocab_size();
                    let logits =
                        compressor
                            .model
                            .forward(&mut compressor.scratch, 0, &mut compressor.state);
                    softmax_pdf_inplace(logits, vocab_size, &mut compressor.pdf_buffer);
                    *primed = true;
                }
                let mut bits = 0.0;
                let vocab_size = compressor.vocab_size();
                for &b in data {
                    // Score under the current prediction, then advance the
                    // model and refresh `pdf_buffer` for the next byte.
                    let p = compressor.pdf_buffer[b as usize].max(1e-12);
                    bits -= p.log2();
                    let logits = compressor.model.forward(
                        &mut compressor.scratch,
                        b as u32,
                        &mut compressor.state,
                    );
                    softmax_pdf_inplace(logits, vocab_size, &mut compressor.pdf_buffer);
                }
                bits
            }
            TraceModel::Zpaq { model } => model.update_and_score(data),
            TraceModel::Mixture { model, .. } => {
                let mut bits = 0.0;
                for &b in data {
                    // NOTE(review): dividing by ln 2 assumes `log_prob` is a
                    // natural log — confirm in `RateBackendPredictor`.
                    let logp = model.log_prob(b);
                    bits -= logp / std::f64::consts::LN_2;
                    model.update(b);
                }
                bits
            }
        }
    }
}
890
891// ============================================================================
892// Fuzz State
893// ============================================================================
894
/// Mutable state carried between steps when actions come from the fuzzer.
struct FuzzState {
    /// Input the next mutation starts from: initialised with the first configured seed
    /// and overwritten with every mutated input that is produced.
    current: Vec<u8>,
    /// Random generator handed to the mutators.
    rng: RandomGenerator,
}
899
900// ============================================================================
901// NyxVmEnvironment
902// ============================================================================
903
/// High-performance VM environment using nyx-lite.
///
/// Bundles the VM handle with the reward/observation configuration and the per-episode
/// bookkeeping needed to drive it step by step.
pub struct NyxVmEnvironment {
    /// Configuration.
    config: NyxVmConfig,
    /// The nyx-lite VM instance.
    vm: NyxVM,
    /// Base snapshot for episode resets; `None` until initialization has captured it.
    base_snapshot: Option<Arc<NyxSnapshot>>,
    /// Shared memory virtual address in guest, recorded when the guest registers the
    /// region whose name matches the configuration.
    shared_vaddr: Option<u64>,
    /// CR3 used when shared memory was registered; later used to address that region.
    shared_cr3: Option<u64>,
    /// Trace model for entropy-based rewards (only built for trace-entropy shaping).
    trace_model: Option<TraceModel>,
    /// Baseline entropy for entropy reduction rewards, computed once at construction.
    baseline_entropy: Option<f64>,
    /// Effective reward shaping policy (additive), copied from the configuration.
    reward_shaping: Option<NyxRewardShaping>,
    /// Fuzzing state; present only when the action source is the fuzzer.
    fuzz_state: Option<FuzzState>,

    // Current step state
    /// Current observation.
    obs: PerceptVal,
    /// Current reward.
    rew: Reward,
    /// Current observation stream.
    obs_stream: Vec<PerceptVal>,
    /// Step within current episode; set back to zero by `reset`.
    step_in_episode: usize,
    /// Whether the environment needs reset; starts `true` and is cleared by
    /// initialization and by `reset`.
    needs_reset: bool,
    /// Whether the VM has been initialized (booted up to the base snapshot).
    initialized: bool,
}
939
940impl NyxVmEnvironment {
941    /// Creates a new NyxVmEnvironment with the given configuration.
942    pub fn new(config: NyxVmConfig) -> anyhow::Result<Self> {
943        // Validate configuration
944        if config.firecracker_config.is_empty() {
945            return Err(anyhow::anyhow!("firecracker_config path must be set"));
946        }
947        if config.episode_steps == 0 {
948            return Err(anyhow::anyhow!("episode_steps must be > 0"));
949        }
950        if matches!(config.observation_policy, NyxObservationPolicy::RawOutput)
951            && config.observation_stream_len == 0
952        {
953            return Err(anyhow::anyhow!(
954                "observation_stream_len must be > 0 for RawOutput policy"
955            ));
956        }
957
958        // Load Firecracker config and resolve relative paths
959        let fc_config_raw = std::fs::read_to_string(&config.firecracker_config)
960            .map_err(|e| anyhow::anyhow!("Failed to read firecracker config: {}", e))?;
961        let fc_config =
962            rewrite_firecracker_config_paths(&config.firecracker_config, &fc_config_raw)
963                .map_err(|e| anyhow::anyhow!("Failed to parse firecracker config: {}", e))?;
964
965        // Create the VM
966        let vm = NyxVM::new(config.instance_id.clone(), &fc_config);
967
968        // Initialize reward shaping
969        let reward_shaping = config.reward_shaping.clone();
970
971        if matches!(reward_shaping, Some(NyxRewardShaping::TraceEntropy { .. }))
972            && config.trace.is_none()
973        {
974            return Err(anyhow::anyhow!(
975                "vm_trace must be configured for vm_reward_shaping.mode=trace-entropy"
976            ));
977        }
978
979        // Initialize trace model if needed
980        let trace_model = match &reward_shaping {
981            Some(NyxRewardShaping::TraceEntropy { max_order, .. }) => {
982                Some(TraceModel::new(&config.stats_backend, *max_order))
983            }
984            _ => None,
985        };
986
987        // Compute baseline entropy if needed
988        let baseline_entropy = match &reward_shaping {
989            Some(NyxRewardShaping::EntropyReduction {
990                baseline_bytes,
991                max_order,
992                ..
993            }) => {
994                let h = if *max_order == 0 {
995                    marginal_entropy_bytes(baseline_bytes)
996                } else {
997                    entropy_rate_backend(baseline_bytes, *max_order, &config.stats_backend)
998                };
999                Some(h)
1000            }
1001            _ => None,
1002        };
1003
1004        // Initialize fuzz state if needed
1005        let fuzz_state = match &config.action_source {
1006            NyxActionSource::Fuzz(fuzz) => {
1007                if fuzz.seeds.is_empty() {
1008                    return Err(anyhow::anyhow!("Fuzz mode requires at least one seed"));
1009                }
1010                if fuzz.mutators.is_empty() {
1011                    return Err(anyhow::anyhow!("Fuzz mode requires at least one mutator"));
1012                }
1013                let seed = fuzz.seeds[0].clone();
1014                Some(FuzzState {
1015                    current: seed,
1016                    rng: RandomGenerator::new().fork_with(fuzz.rng_seed),
1017                })
1018            }
1019            NyxActionSource::Literal(actions) => {
1020                if actions.is_empty() {
1021                    return Err(anyhow::anyhow!("Literal mode requires at least one action"));
1022                }
1023                None
1024            }
1025        };
1026
1027        let mut env = Self {
1028            config,
1029            vm,
1030            base_snapshot: None,
1031            shared_vaddr: None,
1032            shared_cr3: None,
1033            trace_model,
1034            baseline_entropy,
1035            reward_shaping,
1036            fuzz_state,
1037            obs: 0,
1038            rew: 0,
1039            obs_stream: Vec::new(),
1040            step_in_episode: 0,
1041            needs_reset: true,
1042            initialized: false,
1043        };
1044
1045        // Boot and initialize
1046        env.initialize()?;
1047
1048        Ok(env)
1049    }
1050
1051    /// Initializes the VM by booting to the snapshot point.
1052    fn initialize(&mut self) -> anyhow::Result<()> {
1053        if self.initialized {
1054            return Ok(());
1055        }
1056
1057        if self.config.debug_mode {
1058            eprintln!("[NyxVm] Booting VM...");
1059        }
1060
1061        // Run until we get the shared memory registration
1062        let start = Instant::now();
1063        loop {
1064            if start.elapsed() > self.config.boot_timeout {
1065                return Err(anyhow::anyhow!("Boot timeout waiting for shared memory"));
1066            }
1067
1068            let exit = self.vm.run(Duration::from_secs(1));
1069            match exit {
1070                ExitReason::SharedMem(name, vaddr, size) => {
1071                    if self.config.debug_mode {
1072                        eprintln!(
1073                            "[NyxVm] Shared memory registered: {} @ {:#x} ({} bytes)",
1074                            name, vaddr, size
1075                        );
1076                    }
1077                    if name.trim_end_matches('\0') == self.config.shared_region_name {
1078                        self.shared_vaddr = Some(vaddr);
1079                        self.shared_cr3 = Some(self.vm.sregs().cr3);
1080                        // Register the shared region with the configured policy
1081                        let _ = self.vm.register_shared_region_current(
1082                            vaddr,
1083                            size,
1084                            self.config.shared_memory_policy,
1085                        );
1086                        break;
1087                    }
1088                }
1089                ExitReason::DebugPrint(msg) => {
1090                    if self.config.debug_mode {
1091                        eprintln!("[NyxVm] Guest: {}", msg);
1092                    }
1093                }
1094                ExitReason::Shutdown => {
1095                    return Err(anyhow::anyhow!("VM shut down during boot"));
1096                }
1097                _ => {
1098                    if self.config.debug_mode {
1099                        eprintln!("[NyxVm] Boot exit: {:?}", exit);
1100                    }
1101                    // Continue waiting
1102                }
1103            }
1104        }
1105
1106        // Continue running until snapshot request
1107        loop {
1108            if start.elapsed() > self.config.boot_timeout {
1109                return Err(anyhow::anyhow!("Boot timeout waiting for snapshot request"));
1110            }
1111
1112            let exit = self.vm.run(Duration::from_secs(1));
1113            match exit {
1114                ExitReason::RequestSnapshot => {
1115                    if self.config.debug_mode {
1116                        eprintln!("[NyxVm] Taking base snapshot...");
1117                    }
1118                    self.base_snapshot = Some(self.vm.take_base_snapshot());
1119                    break;
1120                }
1121                ExitReason::DebugPrint(msg) => {
1122                    if self.config.debug_mode {
1123                        eprintln!("[NyxVm] Guest: {}", msg);
1124                    }
1125                }
1126                ExitReason::Shutdown => {
1127                    return Err(anyhow::anyhow!("VM shut down before snapshot"));
1128                }
1129                _ => {
1130                    if self.config.debug_mode {
1131                        eprintln!("[NyxVm] Snapshot wait exit: {:?}", exit);
1132                    }
1133                    // Continue waiting
1134                }
1135            }
1136        }
1137
1138        if self.config.debug_mode {
1139            eprintln!("[NyxVm] Initialization complete");
1140        }
1141
1142        self.initialized = true;
1143        self.needs_reset = false;
1144        Ok(())
1145    }
1146
1147    /// Resets to the base snapshot.
1148    pub fn reset(&mut self) -> anyhow::Result<()> {
1149        let snapshot = self
1150            .base_snapshot
1151            .as_ref()
1152            .ok_or_else(|| anyhow::anyhow!("No base snapshot available"))?
1153            .clone();
1154
1155        self.vm.apply_snapshot(&snapshot);
1156
1157        // Reset trace model if configured
1158        if let Some(trace_cfg) = &self.config.trace
1159            && trace_cfg.reset_on_episode
1160            && let Some(model) = &mut self.trace_model
1161        {
1162            model.reset();
1163        }
1164
1165        self.step_in_episode = 0;
1166        self.needs_reset = false;
1167
1168        Ok(())
1169    }
1170
1171    /// Writes action data to shared memory.
1172    fn write_action_to_shared_memory(&mut self, payload: &[u8]) -> anyhow::Result<()> {
1173        let vaddr = self
1174            .shared_vaddr
1175            .ok_or_else(|| anyhow::anyhow!("Shared memory not initialized"))?;
1176        let cr3 = self
1177            .shared_cr3
1178            .ok_or_else(|| anyhow::anyhow!("Shared memory CR3 not initialized"))?;
1179        let process = self.vm.process_memory(cr3);
1180
1181        // Ensure guest has cleared the previous message length to avoid races.
1182        let wait_start = Instant::now();
1183        loop {
1184            let cur_len = process
1185                .read_u64(vaddr + SHARED_ACTION_LEN_OFFSET)
1186                .unwrap_or(0);
1187            if cur_len == 0 {
1188                break;
1189            }
1190            if wait_start.elapsed() > self.config.step_timeout {
1191                return Err(anyhow::anyhow!("shared buffer busy (len={cur_len})"));
1192            }
1193            std::thread::yield_now();
1194        }
1195
1196        // Write length as first 8 bytes (u64 LE)
1197        let len = payload.len() as u64;
1198        process
1199            .write_u64(vaddr + SHARED_ACTION_LEN_OFFSET, len)
1200            .map_err(|e| anyhow::anyhow!("write len failed: {e}"))?;
1201        let _ = process.write_u64(vaddr + SHARED_RESP_LEN_OFFSET, 0);
1202
1203        // Write payload starting at offset 8
1204        let max_len = self
1205            .config
1206            .shared_region_size
1207            .saturating_sub(SHARED_PAYLOAD_OFFSET as usize);
1208        let write_len = payload.len().min(max_len);
1209        if write_len > 0 {
1210            let _ = process
1211                .write_bytes(vaddr + SHARED_PAYLOAD_OFFSET, &payload[..write_len])
1212                .map_err(|e| anyhow::anyhow!("write payload failed: {e}"))?;
1213        }
1214
1215        if self.config.debug_mode {
1216            let verify = process
1217                .read_u64(vaddr + SHARED_ACTION_LEN_OFFSET)
1218                .unwrap_or(0) as usize;
1219            eprintln!(
1220                "[NyxVm] Wrote action len={}, verified len={}",
1221                write_len, verify
1222            );
1223        }
1224
1225        Ok(())
1226    }
1227
1228    /// Reads response from shared memory.
1229    fn read_shared_memory(&self) -> Vec<u8> {
1230        let Some(vaddr) = self.shared_vaddr else {
1231            return Vec::new();
1232        };
1233        let Some(cr3) = self.shared_cr3 else {
1234            return Vec::new();
1235        };
1236        let process = self.vm.process_memory(cr3);
1237
1238        // Read length from first 8 bytes
1239        let len = process
1240            .read_u64(vaddr + SHARED_RESP_LEN_OFFSET)
1241            .unwrap_or(0) as usize;
1242        let max_len = self
1243            .config
1244            .shared_region_size
1245            .saturating_sub(SHARED_PAYLOAD_OFFSET as usize);
1246        let read_len = len.min(max_len);
1247
1248        if read_len == 0 {
1249            return Vec::new();
1250        }
1251
1252        let mut buf = vec![0u8; read_len];
1253        let _ = process.read_bytes(vaddr + SHARED_PAYLOAD_OFFSET, &mut buf);
1254        buf
1255    }
1256
1257    fn clear_shared_length(&self) {
1258        let (Some(vaddr), Some(cr3)) = (self.shared_vaddr, self.shared_cr3) else {
1259            return;
1260        };
1261        let process = self.vm.process_memory(cr3);
1262        let _ = process.write_u64(vaddr + SHARED_ACTION_LEN_OFFSET, 0);
1263        let _ = process.write_u64(vaddr + SHARED_RESP_LEN_OFFSET, 0);
1264    }
1265
1266    /// Runs a single step, returning detailed results.
1267    pub fn run_step(&mut self, payload: &[u8]) -> anyhow::Result<NyxStepResult> {
1268        // Write action to shared memory
1269        self.write_action_to_shared_memory(payload)?;
1270
1271        // Run the VM until we get a meaningful exit
1272        let start = Instant::now();
1273        let mut output = Vec::new();
1274        let mut trace_data = Vec::new();
1275        let mut parsed_obs = None;
1276        let mut parsed_rew = None;
1277        let mut done = false;
1278        let exit_kind;
1279        let collect_output =
1280            matches!(
1281                self.config.observation_policy,
1282                NyxObservationPolicy::OutputHash | NyxObservationPolicy::RawOutput
1283            ) || matches!(self.config.reward_policy, NyxRewardPolicy::Pattern { .. })
1284                || matches!(
1285                    self.reward_shaping,
1286                    Some(NyxRewardShaping::EntropyReduction { .. })
1287                );
1288
1289        loop {
1290            let remaining = self
1291                .config
1292                .step_timeout
1293                .checked_sub(start.elapsed())
1294                .unwrap_or(Duration::ZERO);
1295
1296            if remaining.is_zero() {
1297                exit_kind = NyxExitKind::Timeout;
1298                break;
1299            }
1300
1301            let exit = self.vm.run(remaining);
1302            match exit {
1303                ExitReason::ExecDone(code) => {
1304                    exit_kind = NyxExitKind::ExecDone(code);
1305                    done = true;
1306                    break;
1307                }
1308                ExitReason::Timeout => {
1309                    if self.config.debug_mode {
1310                        eprintln!("[NyxVm] Step timeout");
1311                    }
1312                    exit_kind = NyxExitKind::Timeout;
1313                    break;
1314                }
1315                ExitReason::Shutdown => {
1316                    if self.config.debug_mode {
1317                        eprintln!("[NyxVm] VM shutdown during step");
1318                    }
1319                    exit_kind = NyxExitKind::Shutdown;
1320                    done = true;
1321                    break;
1322                }
1323                ExitReason::DebugPrint(msg) => {
1324                    if self.config.debug_mode {
1325                        eprintln!("[NyxVm] Guest: {}", msg);
1326                    }
1327                    // Accumulate debug output
1328                    if collect_output {
1329                        output.extend_from_slice(msg.as_bytes());
1330                    }
1331                    // Continue running
1332                }
1333                ExitReason::Hypercall(r8, r9, r10, r11, r12) => {
1334                    exit_kind = NyxExitKind::Hypercall {
1335                        code: r8,
1336                        arg1: r9,
1337                        arg2: r10,
1338                        arg3: r11,
1339                        arg4: r12,
1340                    };
1341                    // Attempt to parse structured response
1342                    if let Some(obs) = Self::try_parse_u64(r9) {
1343                        parsed_obs = Some(obs);
1344                    }
1345                    if let Some(rew) = Self::try_parse_i64(r10) {
1346                        parsed_rew = Some(rew);
1347                    }
1348                    break;
1349                }
1350                ExitReason::Breakpoint => {
1351                    if self.config.debug_mode {
1352                        eprintln!("[NyxVm] Breakpoint exit during step");
1353                    }
1354                    exit_kind = NyxExitKind::Breakpoint;
1355                    break;
1356                }
1357                _ => {
1358                    // Continue for other exits
1359                }
1360            }
1361        }
1362
1363        // Read shared memory contents (only if needed)
1364        let need_shared_memory = matches!(
1365            self.config.observation_policy,
1366            NyxObservationPolicy::SharedMemory
1367        ) || matches!(
1368            self.config.reward_policy,
1369            NyxRewardPolicy::Pattern { .. }
1370        ) || matches!(
1371            self.reward_shaping,
1372            Some(NyxRewardShaping::EntropyReduction { .. })
1373        ) || self.config.trace.is_some();
1374        let shared_memory = if need_shared_memory {
1375            self.read_shared_memory()
1376        } else {
1377            Vec::new()
1378        };
1379
1380        // Clear shared length to avoid host/guest races on the next step.
1381        self.clear_shared_length();
1382
1383        // Collect trace data if configured
1384        if let Some(trace_cfg) = &self.config.trace
1385            && trace_cfg.shared_region_name.is_some()
1386        {
1387            // Read from trace shared memory region (implementation-specific)
1388            // For now, use main shared memory as fallback
1389            trace_data = shared_memory.clone();
1390            if trace_data.len() > trace_cfg.max_bytes {
1391                trace_data.truncate(trace_cfg.max_bytes);
1392            }
1393        }
1394
1395        Ok(NyxStepResult {
1396            exit_reason: exit_kind,
1397            output,
1398            parsed_obs,
1399            parsed_rew,
1400            done,
1401            trace_data,
1402            shared_memory,
1403        })
1404    }
1405
1406    fn try_parse_u64(val: u64) -> Option<u64> {
1407        // Hypercall args are already u64
1408        Some(val)
1409    }
1410
1411    fn try_parse_i64(val: u64) -> Option<i64> {
1412        Some(val as i64)
1413    }
1414
1415    /// Gets the action payload for the given action index.
1416    fn get_action_payload(&mut self, action: Action) -> anyhow::Result<Cow<'_, [u8]>> {
1417        match &self.config.action_source {
1418            NyxActionSource::Literal(actions) => {
1419                let idx = action as usize;
1420                if idx >= actions.len() {
1421                    return Err(anyhow::anyhow!("Action index out of range"));
1422                }
1423                Ok(Cow::Borrowed(actions[idx].payload.as_slice()))
1424            }
1425            NyxActionSource::Fuzz(fuzz) => {
1426                let state = self
1427                    .fuzz_state
1428                    .as_mut()
1429                    .ok_or_else(|| anyhow::anyhow!("Fuzz state missing"))?;
1430                let idx = action as usize % fuzz.mutators.len();
1431                let mut input = state.current.clone();
1432                let mutator = &fuzz.mutators[idx];
1433                apply_mutator(mutator, &mut input, fuzz, &mut state.rng);
1434                if input.len() < fuzz.min_len {
1435                    input.resize(fuzz.min_len, 0);
1436                }
1437                if input.len() > fuzz.max_len {
1438                    input.truncate(fuzz.max_len);
1439                }
1440                state.current = input.clone();
1441                Ok(Cow::Owned(input))
1442            }
1443        }
1444    }
1445
1446    /// Applies action filtering, returning reject reward if filtered.
1447    fn filter_action(&self, payload: &[u8]) -> Option<i64> {
1448        let filter = self.config.action_filter.as_ref()?;
1449        if payload.is_empty() {
1450            return filter.reject_reward;
1451        }
1452
1453        let (entropy, intrinsic, novelty) = self.compute_filter_metrics(payload, filter);
1454
1455        if let Some(min_entropy) = filter.min_entropy
1456            && entropy < min_entropy
1457        {
1458            return filter.reject_reward;
1459        }
1460        if let Some(max_entropy) = filter.max_entropy
1461            && entropy > max_entropy
1462        {
1463            return filter.reject_reward;
1464        }
1465        if let Some(min_intrinsic) = filter.min_intrinsic_dependence
1466            && intrinsic < min_intrinsic
1467        {
1468            return filter.reject_reward;
1469        }
1470        if let Some(min_novelty) = filter.min_novelty
1471            && filter.novelty_prior.is_some()
1472            && novelty < min_novelty
1473        {
1474            return filter.reject_reward;
1475        }
1476        None
1477    }
1478
1479    fn wrap_action_payload(&self, payload: &[u8]) -> Vec<u8> {
1480        let p = &self.config.protocol;
1481        let mut wrapped = p.action_prefix.clone().into_bytes();
1482        wrapped.extend_from_slice(p.wire_encoding.encode(payload).as_bytes());
1483        wrapped.extend_from_slice(p.action_suffix.as_bytes());
1484        wrapped
1485    }
1486
1487    fn compute_filter_metrics(&self, payload: &[u8], filter: &NyxActionFilter) -> (f64, f64, f64) {
1488        let h_marg = marginal_entropy_bytes(payload);
1489        let h_rate = if filter.max_order == 0 {
1490            h_marg
1491        } else {
1492            entropy_rate_backend(payload, filter.max_order, &self.config.stats_backend)
1493        };
1494
1495        let intrinsic = if h_marg < 1e-9 {
1496            0.0
1497        } else {
1498            ((h_marg - h_rate) / h_marg).clamp(0.0, 1.0)
1499        };
1500
1501        let novelty = if let Some(ref prior) = filter.novelty_prior {
1502            cross_entropy_rate_backend(payload, prior, filter.max_order, &self.config.stats_backend)
1503        } else {
1504            0.0
1505        };
1506
1507        (h_rate, intrinsic, novelty)
1508    }
1509
1510    /// Computes reward from step result.
1511    fn compute_reward(&mut self, result: &NyxStepResult) -> Reward {
1512        let base_reward = match &self.config.reward_policy {
1513            NyxRewardPolicy::FromGuest => result.parsed_rew.unwrap_or(0),
1514            NyxRewardPolicy::Pattern {
1515                pattern,
1516                base_reward,
1517                bonus_reward,
1518            } => {
1519                let text = String::from_utf8_lossy(&result.output);
1520                let shared_text = String::from_utf8_lossy(&result.shared_memory);
1521                if text.contains(pattern) || shared_text.contains(pattern) {
1522                    base_reward + bonus_reward
1523                } else {
1524                    *base_reward
1525                }
1526            }
1527            NyxRewardPolicy::Custom(f) => f(result),
1528        };
1529
1530        let shaping_reward = if let Some(shaping) = self.reward_shaping.clone() {
1531            self.compute_reward_shaping(&shaping, result)
1532        } else {
1533            0
1534        };
1535
1536        let mut reward = base_reward.saturating_add(shaping_reward);
1537
1538        reward = reward.saturating_sub(self.config.step_cost);
1539        let min_reward = self.min_reward();
1540        let max_reward = self.max_reward();
1541        reward.clamp(min_reward, max_reward)
1542    }
1543
1544    fn compute_reward_shaping(
1545        &mut self,
1546        shaping: &NyxRewardShaping,
1547        result: &NyxStepResult,
1548    ) -> Reward {
1549        match shaping {
1550            NyxRewardShaping::EntropyReduction {
1551                max_order,
1552                scale,
1553                crash_bonus,
1554                timeout_bonus,
1555                ..
1556            } => {
1557                let mut base_reward = {
1558                    let data = if result.shared_memory.is_empty() {
1559                        &result.output
1560                    } else {
1561                        &result.shared_memory
1562                    };
1563                    let h_obs = if *max_order == 0 {
1564                        marginal_entropy_bytes(data)
1565                    } else {
1566                        entropy_rate_backend(data, *max_order, &self.config.stats_backend)
1567                    };
1568                    let h_base = self.baseline_entropy.unwrap_or(0.0);
1569                    let er = (h_base - h_obs) * scale;
1570                    er.round() as i64
1571                };
1572
1573                // Add bonuses for interesting behaviors (bugs/crashes)
1574                match &result.exit_reason {
1575                    NyxExitKind::Shutdown | NyxExitKind::Breakpoint => {
1576                        if let Some(bonus) = crash_bonus {
1577                            base_reward = base_reward.saturating_add(*bonus);
1578                        }
1579                    }
1580                    NyxExitKind::Timeout => {
1581                        if let Some(bonus) = timeout_bonus {
1582                            base_reward = base_reward.saturating_add(*bonus);
1583                        }
1584                    }
1585                    _ => {}
1586                }
1587
1588                base_reward
1589            }
1590            NyxRewardShaping::TraceEntropy {
1591                scale, normalize, ..
1592            } => {
1593                let data = &result.trace_data;
1594                let bits = match self.trace_model.as_mut() {
1595                    Some(model) => model.update_and_score(data),
1596                    None => 0.0,
1597                };
1598                let bits = if *normalize && !data.is_empty() {
1599                    bits / data.len() as f64
1600                } else {
1601                    bits
1602                };
1603                (bits * scale).round() as i64
1604            }
1605        }
1606    }
1607
1608    fn mask_observation(&self, value: u64) -> u64 {
1609        let bits = self.config.observation_bits;
1610        if bits >= 64 {
1611            value
1612        } else if bits == 0 {
1613            0
1614        } else {
1615            value & ((1u64 << bits) - 1)
1616        }
1617    }
1618
1619    fn build_observation_stream(&self, result: &NyxStepResult) -> Vec<PerceptVal> {
1620        let mut observations = match self.config.observation_policy {
1621            NyxObservationPolicy::FromGuest => {
1622                if let Some(obs) = result.parsed_obs {
1623                    vec![self.mask_observation(obs)]
1624                } else {
1625                    vec![self.hash_observation(&result.shared_memory)]
1626                }
1627            }
1628            NyxObservationPolicy::OutputHash => {
1629                vec![self.hash_observation(&result.output)]
1630            }
1631            NyxObservationPolicy::RawOutput => {
1632                result.output.iter().map(|b| *b as PerceptVal).collect()
1633            }
1634            NyxObservationPolicy::SharedMemory => result
1635                .shared_memory
1636                .iter()
1637                .map(|b| *b as PerceptVal)
1638                .collect(),
1639        };
1640
1641        if observations.is_empty() {
1642            observations.push(0);
1643        }
1644
1645        self.normalize_observation_stream(&mut observations);
1646        observations
1647    }
1648
1649    fn hash_observation(&self, data: &[u8]) -> PerceptVal {
1650        let h = robust_hash_bytes(data);
1651        self.mask_observation(h)
1652    }
1653
1654    fn normalize_observation_stream(&self, observations: &mut Vec<PerceptVal>) {
1655        let mask = if self.config.observation_bits >= 64 {
1656            u64::MAX
1657        } else if self.config.observation_bits == 0 {
1658            0
1659        } else {
1660            (1u64 << self.config.observation_bits) - 1
1661        };
1662
1663        for obs in observations.iter_mut() {
1664            *obs &= mask;
1665        }
1666
1667        let target = self.config.observation_stream_len;
1668        if target == 0 {
1669            return;
1670        }
1671
1672        if observations.len() > target {
1673            match self.config.observation_stream_mode {
1674                NyxObservationStreamMode::Truncate | NyxObservationStreamMode::PadTruncate => {
1675                    observations.truncate(target);
1676                }
1677                NyxObservationStreamMode::Pad => {}
1678            }
1679        } else if observations.len() < target {
1680            match self.config.observation_stream_mode {
1681                NyxObservationStreamMode::Pad | NyxObservationStreamMode::PadTruncate => {
1682                    let pad = self.config.observation_pad_byte as PerceptVal;
1683                    observations.resize(target, pad);
1684                }
1685                NyxObservationStreamMode::Truncate => {}
1686            }
1687        }
1688    }
1689
1690    fn action_count(&self) -> usize {
1691        match &self.config.action_source {
1692            NyxActionSource::Literal(actions) => actions.len(),
1693            NyxActionSource::Fuzz(fuzz) => fuzz.mutators.len(),
1694        }
1695    }
1696
1697    /// Direct access to the underlying NyxVM for advanced use cases.
1698    pub fn vm(&self) -> &NyxVM {
1699        &self.vm
1700    }
1701
1702    /// Mutable access to the underlying NyxVM.
1703    pub fn vm_mut(&mut self) -> &mut NyxVM {
1704        &mut self.vm
1705    }
1706
1707    /// Takes a new snapshot at the current state.
1708    pub fn take_snapshot(&mut self) -> Arc<NyxSnapshot> {
1709        self.vm.take_snapshot()
1710    }
1711
1712    /// Applies a specific snapshot.
1713    pub fn apply_snapshot(&mut self, snapshot: &Arc<NyxSnapshot>) {
1714        self.vm.apply_snapshot(snapshot);
1715    }
1716
1717    /// Resets trace model.
1718    pub fn reset_trace_model(&mut self) {
1719        if let Some(model) = &mut self.trace_model {
1720            model.reset();
1721        }
1722    }
1723
1724    /// Logs crashes and interesting behaviors to file.
1725    fn log_crash(&self, action_payload: &[u8], result: &NyxStepResult, reward: i64) {
1726        let Some(log_path) = &self.config.crash_log else {
1727            return;
1728        };
1729
1730        // Only log interesting exits
1731        let is_interesting = matches!(
1732            result.exit_reason,
1733            NyxExitKind::Shutdown | NyxExitKind::Breakpoint | NyxExitKind::Timeout
1734        );
1735
1736        if !is_interesting {
1737            return;
1738        }
1739
1740        let log_entry = serde_json::json!({
1741            "timestamp": std::time::SystemTime::now()
1742                .duration_since(std::time::UNIX_EPOCH)
1743                .unwrap_or_default()
1744                .as_secs(),
1745            "exit_reason": format!("{:?}", result.exit_reason),
1746            "action_payload": hex_encode(action_payload),
1747            "action_payload_str": String::from_utf8_lossy(action_payload),
1748            "output": String::from_utf8_lossy(&result.output),
1749            "shared_memory": hex_encode(&result.shared_memory),
1750            "reward": reward,
1751            "parsed_obs": result.parsed_obs,
1752            "parsed_rew": result.parsed_rew,
1753        });
1754
1755        // Append to JSONL file
1756        if let Ok(mut file) = OpenOptions::new().create(true).append(true).open(log_path)
1757            && let Ok(json_str) = serde_json::to_string(&log_entry)
1758        {
1759            let _ = writeln!(file, "{}", json_str);
1760        }
1761    }
1762}
1763
1764// ============================================================================
1765// Environment Trait Implementation
1766// ============================================================================
1767
1768impl Environment for NyxVmEnvironment {
1769    fn perform_action(&mut self, action: Action) {
1770        if self.needs_reset
1771            && let Err(e) = self.reset()
1772            && self.config.debug_mode
1773        {
1774            eprintln!("[NyxVm] Reset failed: {}", e);
1775        }
1776
1777        let payload = match self.get_action_payload(action) {
1778            Ok(payload) => payload.into_owned(),
1779            Err(e) => {
1780                if self.config.debug_mode {
1781                    eprintln!("[NyxVm] Invalid action: {}", e);
1782                }
1783                self.obs = 0;
1784                self.rew = self.min_reward();
1785                self.obs_stream.clear();
1786                self.obs_stream.push(0);
1787                self.step_in_episode = (self.step_in_episode + 1) % self.config.episode_steps;
1788                if self.step_in_episode == 0 {
1789                    self.needs_reset = true;
1790                }
1791                return;
1792            }
1793        };
1794
1795        // Check action filter
1796        if let Some(reject_reward) = self.filter_action(&payload) {
1797            self.obs = 0;
1798            self.rew = reject_reward.clamp(self.min_reward(), self.max_reward());
1799            self.obs_stream.clear();
1800            self.obs_stream.push(0);
1801            self.step_in_episode = (self.step_in_episode + 1) % self.config.episode_steps;
1802            if self.step_in_episode == 0 {
1803                self.needs_reset = true;
1804            }
1805            return;
1806        }
1807
1808        // Run the step
1809        let wrapped_payload = self.wrap_action_payload(&payload);
1810        let result = match self.run_step(&wrapped_payload) {
1811            Ok(result) => result,
1812            Err(e) => {
1813                if self.config.debug_mode {
1814                    eprintln!("[NyxVm] Step failed: {}", e);
1815                }
1816                self.obs = 0;
1817                self.rew = self.min_reward();
1818                self.obs_stream.clear();
1819                self.obs_stream.push(0);
1820                self.step_in_episode = (self.step_in_episode + 1) % self.config.episode_steps;
1821                if self.step_in_episode == 0 {
1822                    self.needs_reset = true;
1823                }
1824                return;
1825            }
1826        };
1827
1828        // Process results
1829        self.obs_stream = self.build_observation_stream(&result);
1830        self.obs = self.obs_stream.first().copied().unwrap_or(0);
1831        self.rew = self.compute_reward(&result);
1832
1833        // Log crashes and interesting behaviors
1834        self.log_crash(&payload, &result, self.rew);
1835
1836        if self.config.debug_mode {
1837            eprintln!(
1838                "[NyxVm] Action={} Obs={} Rew={} Done={:?} Exit={:?}",
1839                action, self.obs, self.rew, result.done, result.exit_reason
1840            );
1841        }
1842
1843        self.step_in_episode = (self.step_in_episode + 1) % self.config.episode_steps;
1844        if self.step_in_episode == 0 || result.done {
1845            self.needs_reset = true;
1846        }
1847    }
1848
1849    fn get_observation(&self) -> PerceptVal {
1850        self.obs
1851    }
1852
1853    fn drain_observations(&mut self) -> Vec<PerceptVal> {
1854        if self.obs_stream.is_empty() {
1855            vec![self.obs]
1856        } else {
1857            std::mem::take(&mut self.obs_stream)
1858        }
1859    }
1860
1861    fn get_reward(&self) -> Reward {
1862        self.rew
1863    }
1864
1865    fn is_finished(&self) -> bool {
1866        false
1867    }
1868
1869    fn get_observation_bits(&self) -> usize {
1870        self.config.observation_bits
1871    }
1872
1873    fn get_reward_bits(&self) -> usize {
1874        self.config.reward_bits
1875    }
1876
1877    fn get_action_bits(&self) -> usize {
1878        let n = self.action_count();
1879        if n <= 1 {
1880            return 1;
1881        }
1882        (n as f64).log2().ceil() as usize
1883    }
1884
1885    fn get_num_actions(&self) -> usize {
1886        self.action_count()
1887    }
1888
1889    fn max_reward(&self) -> Reward {
1890        let bits = self.config.reward_bits;
1891        if bits >= 64 {
1892            i64::MAX
1893        } else if bits == 0 {
1894            0
1895        } else {
1896            (1i64 << (bits - 1)) - 1
1897        }
1898    }
1899
1900    fn min_reward(&self) -> Reward {
1901        let bits = self.config.reward_bits;
1902        if bits >= 64 {
1903            i64::MIN
1904        } else if bits == 0 {
1905            0
1906        } else {
1907            -(1i64 << (bits - 1))
1908        }
1909    }
1910}
1911
1912// ============================================================================
1913// Helper Functions
1914// ============================================================================
1915
/// Hashes a byte slice to a well-mixed 64-bit value.
///
/// The bytes are folded with 64-bit FNV-1a and the result is passed through a
/// multiply/xor-shift avalanche step, so that every output bit depends on every
/// input byte. This matters because the digest is subsequently reduced to its low
/// `observation_bits` bits: a plain rotate-xor fold leaves those low bits determined
/// almost entirely by the final input byte for short inputs, and it cannot
/// distinguish inputs that differ only by leading zero bytes.
///
/// The function is deterministic and uses no per-process seed.
fn robust_hash_bytes(data: &[u8]) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;

    let folded = data.iter().fold(FNV_OFFSET_BASIS, |acc, &byte| {
        (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    });

    // Avalanche finalizer: pushes high-bit entropy down into the low bits.
    let mut h = folded;
    h ^= h >> 33;
    h = h.wrapping_mul(0xff51_afd7_ed55_8ccd);
    h ^= h >> 33;
    h = h.wrapping_mul(0xc4ce_b9fe_1a85_ec53);
    h ^= h >> 33;
    h
}
1923
1924fn apply_mutator(
1925    mutator: &FuzzMutator,
1926    input: &mut Vec<u8>,
1927    fuzz: &NyxFuzzConfig,
1928    rng: &mut RandomGenerator,
1929) {
1930    match mutator {
1931        FuzzMutator::FlipBit => {
1932            if input.is_empty() {
1933                input.push(0);
1934            }
1935            let idx = rng.gen_range(input.len());
1936            let bit = rng.gen_range(8);
1937            input[idx] ^= 1u8 << bit;
1938        }
1939        FuzzMutator::FlipByte => {
1940            if input.is_empty() {
1941                input.push(0);
1942            }
1943            let idx = rng.gen_range(input.len());
1944            input[idx] ^= rng.next_u64() as u8;
1945        }
1946        FuzzMutator::InsertByte => {
1947            let idx = if input.is_empty() {
1948                0
1949            } else {
1950                rng.gen_range(input.len() + 1)
1951            };
1952            let byte = if !fuzz.dictionary.is_empty() {
1953                let d = rng.gen_range(fuzz.dictionary.len());
1954                let entry = &fuzz.dictionary[d];
1955                if entry.is_empty() {
1956                    0
1957                } else {
1958                    entry[rng.gen_range(entry.len())]
1959                }
1960            } else {
1961                rng.next_u64() as u8
1962            };
1963            input.insert(idx, byte);
1964        }
1965        FuzzMutator::DeleteByte => {
1966            if input.len() > 1 {
1967                let idx = rng.gen_range(input.len());
1968                input.remove(idx);
1969            }
1970        }
1971        FuzzMutator::SpliceSeed => {
1972            if fuzz.seeds.is_empty() {
1973                return;
1974            }
1975            let seed = &fuzz.seeds[rng.gen_range(fuzz.seeds.len())];
1976            if input.is_empty() {
1977                input.extend_from_slice(seed);
1978            } else if !seed.is_empty() {
1979                let cut = rng.gen_range(input.len());
1980                let seed_cut = rng.gen_range(seed.len());
1981                let mut out = Vec::new();
1982                out.extend_from_slice(&input[..cut]);
1983                out.extend_from_slice(&seed[seed_cut..]);
1984                *input = out;
1985            }
1986        }
1987        FuzzMutator::ResetSeed => {
1988            if fuzz.seeds.is_empty() {
1989                return;
1990            }
1991            *input = fuzz.seeds[rng.gen_range(fuzz.seeds.len())].clone();
1992        }
1993        FuzzMutator::Havoc => {
1994            let flips = 1 + rng.gen_range(8);
1995            for _ in 0..flips {
1996                if input.is_empty() {
1997                    input.push(0);
1998                }
1999                let idx = rng.gen_range(input.len());
2000                input[idx] ^= rng.next_u64() as u8;
2001            }
2002        }
2003    }
2004}
2005
2006// ============================================================================
2007// Tests
2008// ============================================================================
2009
#[cfg(test)]
mod tests {
    use super::*;

    /// Hex encoding yields lowercase hex and round-trips through `hex_decode`.
    #[test]
    fn test_hex_encoding() {
        let data = b"hello";
        let hex = hex_encode(data);
        assert_eq!(hex, "68656c6c6f");
        assert_eq!(hex_decode(&hex).unwrap(), data);
    }

    /// Equal inputs hash equally; different inputs hash differently.
    #[test]
    fn test_robust_hash() {
        let reference = robust_hash_bytes(b"test data");
        assert_eq!(reference, robust_hash_bytes(b"test data"));
        assert_ne!(reference, robust_hash_bytes(b"different"));
    }

    /// Both payload encodings encode and decode the same bytes consistently.
    #[test]
    fn test_payload_encoding() {
        let data = b"test";

        let utf8 = PayloadEncoding::Utf8;
        assert_eq!(utf8.encode(data), "test");
        assert_eq!(utf8.decode("test").unwrap(), data);

        let hex = PayloadEncoding::Hex;
        assert_eq!(hex.encode(data), "74657374");
        assert_eq!(hex.decode("74657374").unwrap(), data);
    }

    /// Every predictor-backed backend can score a payload, before and after a reset.
    #[test]
    fn trace_model_supports_predictor_backed_backends() {
        let calibrated = crate::CalibratedSpec {
            base: RateBackend::Ctw { depth: 8 },
            context: crate::CalibrationContextKind::Text,
            bins: 33,
            learning_rate: 0.02,
            bias_clip: 4.0,
        };
        let particle = crate::ParticleSpec {
            num_particles: 4,
            num_cells: 4,
            cell_dim: 8,
            ..crate::ParticleSpec::default()
        };
        let mixture = crate::MixtureSpec::new(
            crate::MixtureKind::Bayes,
            vec![crate::MixtureExpertSpec {
                name: Some("ctw".to_string()),
                log_prior: 0.0,
                max_order: -1,
                backend: RateBackend::Ctw { depth: 8 },
            }],
        );

        let backends = [
            RateBackend::Match {
                hash_bits: 20,
                min_len: 4,
                max_len: 255,
                base_mix: 0.02,
                confidence_scale: 1.0,
            },
            RateBackend::SparseMatch {
                hash_bits: 19,
                min_len: 3,
                max_len: 64,
                gap_min: 1,
                gap_max: 2,
                base_mix: 0.05,
                confidence_scale: 1.0,
            },
            RateBackend::Ppmd {
                order: 8,
                memory_mb: 8,
            },
            RateBackend::Calibrated {
                spec: Arc::new(calibrated),
            },
            RateBackend::Particle {
                spec: Arc::new(particle),
            },
            RateBackend::Mixture {
                spec: Arc::new(mixture),
            },
        ];

        for backend in &backends {
            let mut model = TraceModel::new(backend, 4);

            let bits = model.update_and_score(b"trace payload");
            assert!(bits.is_finite() && bits >= 0.0, "bits={bits}");

            model.reset();

            let bits_after_reset = model.update_and_score(b"trace payload");
            assert!(
                bits_after_reset.is_finite() && bits_after_reset >= 0.0,
                "bits_after_reset={bits_after_reset}"
            );
        }
    }
}