infotheory/aixi/
vm_nyx.rs

1//! High-performance VM-backed AIXI environment using nyx-lite (Firecracker).
2//!
3//! This module provides a VM environment implementation built on top of nyx-lite,
4//! enabling high-frequency snapshot-based resets for fast experimentation (hardware and
5//! guest behavior dependent).
6//!
7//! ## Architecture
8//!
9//! The environment uses Firecracker's KVM-based microVM with nyx-lite's incremental
10//! snapshot and reset capabilities. Communication with the guest occurs via:
11//!
12//! 1. **Shared Memory**: Zero-copy data transfer between host and guest
13//! 2. **Hypercalls**: Control plane communication (snapshot, done, etc.)
14//! 3. **Serial PTY**: Optional console I/O for simpler protocols
15//!
16//! ## Design Principles
17//!
18//! - **Universal**: Not biased towards any specific use case (fuzzing, etc.)
19//! - **High Performance**: Leverages incremental snapshots and dirty page tracking
20//! - **Configurable**: Pluggable reward policies, action sources, observation modes
21//! - **Information-Theoretic**: Built-in support for entropy-based metrics
22
23use crate::aixi::common::{Action, PerceptVal, RandomGenerator, Reward};
24use crate::aixi::environment::Environment;
25use crate::mixture::OnlineBytePredictor;
26use crate::{
27    RateBackend, cross_entropy_rate_backend, entropy_rate_backend, marginal_entropy_bytes,
28};
29use crate::zpaq_rate::ZpaqRateModel;
30use rosaplus::RosaPlus;
31use rwkvzip::Compressor;
32use rwkvzip::coders::softmax_pdf_inplace;
33use serde_json::Value;
34use std::borrow::Cow;
35use std::fs::OpenOptions;
36use std::io::Write;
37use std::path::Path;
38use std::sync::Arc;
39use std::time::{Duration, Instant};
40
41// Re-export nyx-lite types for external use
42pub use nyx_lite::mem::SharedMemoryRegion;
43pub use nyx_lite::snapshot::NyxSnapshot;
44pub use nyx_lite::{ExitReason, NyxVM, SharedMemoryPolicy};
45
46// ============================================================================
47// Encoding Types
48// ============================================================================
49
/// Payload encoding for wire protocol.
///
/// Selects how payload bytes are converted to and from text by
/// [`PayloadEncoding::encode`] and [`PayloadEncoding::decode`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PayloadEncoding {
    /// Text is the payload's UTF-8 form; encoding bytes that are not valid
    /// UTF-8 is lossy.
    Utf8,
    /// Text is hexadecimal, two digits per byte (lowercase when encoding).
    Hex,
}
56
57impl PayloadEncoding {
58    pub fn from_str(s: &str) -> Option<Self> {
59        match s {
60            "utf8" | "text" => Some(Self::Utf8),
61            "hex" => Some(Self::Hex),
62            _ => None,
63        }
64    }
65
66    pub fn decode(self, s: &str) -> anyhow::Result<Vec<u8>> {
67        match self {
68            Self::Utf8 => Ok(s.as_bytes().to_vec()),
69            Self::Hex => hex_decode(s),
70        }
71    }
72
73    pub fn encode(self, bytes: &[u8]) -> String {
74        match self {
75            Self::Utf8 => String::from_utf8_lossy(bytes).to_string(),
76            Self::Hex => hex_encode(bytes),
77        }
78    }
79}
80
81fn hex_decode(s: &str) -> anyhow::Result<Vec<u8>> {
82    let mut out = Vec::with_capacity(s.len() / 2);
83    let mut buf = 0u8;
84    let mut high = true;
85    for c in s.bytes() {
86        let v = match c {
87            b'0'..=b'9' => c - b'0',
88            b'a'..=b'f' => c - b'a' + 10,
89            b'A'..=b'F' => c - b'A' + 10,
90            b' ' | b'\n' | b'\r' | b'\t' => continue,
91            _ => return Err(anyhow::anyhow!("invalid hex byte: {}", c as char)),
92        };
93        if high {
94            buf = v << 4;
95            high = false;
96        } else {
97            buf |= v;
98            out.push(buf);
99            high = true;
100        }
101    }
102    if !high {
103        return Err(anyhow::anyhow!("hex string has odd length"));
104    }
105    Ok(out)
106}
107
/// Resolves `path` against `base`.
///
/// Absolute paths are returned unchanged; relative paths are joined onto
/// `base` and converted to a `String` (lossily, if `base` is not valid UTF-8).
fn resolve_relative_path(base: &Path, path: &str) -> String {
    let candidate = Path::new(path);
    if candidate.is_relative() {
        return base.join(candidate).to_string_lossy().into_owned();
    }
    path.to_owned()
}
116
117fn rewrite_firecracker_config_paths(config_path: &str, raw_json: &str) -> anyhow::Result<String> {
118    let base_dir = Path::new(config_path)
119        .parent()
120        .unwrap_or_else(|| Path::new("."));
121    let mut v: Value = serde_json::from_str(raw_json)?;
122
123    if let Some(boot) = v.get_mut("boot-source") {
124        if let Some(path_val) = boot.get_mut("kernel_image_path") {
125            if let Some(path_str) = path_val.as_str() {
126                let resolved = resolve_relative_path(base_dir, path_str);
127                *path_val = Value::String(resolved);
128            }
129        }
130        if let Some(path_val) = boot.get_mut("initrd_path") {
131            if let Some(path_str) = path_val.as_str() {
132                let resolved = resolve_relative_path(base_dir, path_str);
133                *path_val = Value::String(resolved);
134            }
135        }
136    }
137
138    if let Some(drives) = v.get_mut("drives").and_then(|d| d.as_array_mut()) {
139        for drive in drives {
140            if let Some(path_val) = drive.get_mut("path_on_host") {
141                if let Some(path_str) = path_val.as_str() {
142                    let resolved = resolve_relative_path(base_dir, path_str);
143                    *path_val = Value::String(resolved);
144                }
145            }
146        }
147    }
148
149    Ok(serde_json::to_string(&v)?)
150}
151
152fn hex_encode(bytes: &[u8]) -> String {
153    let mut s = String::with_capacity(bytes.len() * 2);
154    for b in bytes {
155        s.push(hex_digit(b >> 4));
156        s.push(hex_digit(b & 0x0F));
157    }
158    s
159}
160
/// Maps a nibble value to its lowercase hex character.
///
/// Intended for values `0..=15`; `0..=9` map to `'0'..='9'` and `10..=15`
/// map to `'a'..='f'`.
fn hex_digit(v: u8) -> char {
    if v <= 9 {
        char::from(b'0' + v)
    } else {
        char::from(b'a' + (v - 10))
    }
}
167
168// ============================================================================
169// Guest Communication Protocol
170// ============================================================================
171
/// Hypercall identifiers (must match guest implementation).
/// These are exported for use by custom guest programs.
///
/// Each identifier is an 8-character ASCII tag read as a little-endian `u64`.
#[allow(dead_code)]
pub const HYPERCALL_EXECDONE: u64 = u64::from_le_bytes(*b"execdone"); // 0x656e6f6463657865
#[allow(dead_code)]
pub const HYPERCALL_SNAPSHOT: u64 = u64::from_le_bytes(*b"snapshot"); // 0x746f687370616e73
#[allow(dead_code)]
pub const HYPERCALL_NYX_LITE: u64 = u64::from_le_bytes(*b"nyx-lite"); // 0x6574696c2d78796e
#[allow(dead_code)]
pub const HYPERCALL_SHAREMEM: u64 = u64::from_le_bytes(*b"sharemem"); // 0x6d656d6572616873
#[allow(dead_code)]
pub const HYPERCALL_DBGPRINT: u64 = u64::from_le_bytes(*b"dbgprint"); // 0x746e697270676264

// Byte offsets inside the shared region.
// NOTE(review): names suggest an action-length field, a response-length field,
// then the payload — confirm against the guest-side layout.
const SHARED_ACTION_LEN_OFFSET: u64 = 0x00;
const SHARED_RESP_LEN_OFFSET: u64 = 0x08;
const SHARED_PAYLOAD_OFFSET: u64 = 0x10;
188
/// Protocol configuration for structured communication.
///
/// The [`Default`] implementation uses the prefixes `"ACT "`, `"OBS "`,
/// `"REW "`, `"DONE "` and `"DATA "`, a newline as the action suffix, and
/// hex wire encoding.
#[derive(Clone, Debug)]
pub struct NyxProtocolConfig {
    /// Prefix for action messages.
    pub action_prefix: String,
    /// Suffix for action messages.
    pub action_suffix: String,
    /// Prefix for observation responses.
    pub obs_prefix: String,
    /// Prefix for reward responses.
    pub rew_prefix: String,
    /// Prefix for done indicator.
    pub done_prefix: String,
    /// Prefix for data payloads.
    pub data_prefix: String,
    /// Wire encoding for payloads.
    pub wire_encoding: PayloadEncoding,
}
207
208impl Default for NyxProtocolConfig {
209    fn default() -> Self {
210        Self {
211            action_prefix: "ACT ".to_string(),
212            action_suffix: "\n".to_string(),
213            obs_prefix: "OBS ".to_string(),
214            rew_prefix: "REW ".to_string(),
215            done_prefix: "DONE ".to_string(),
216            data_prefix: "DATA ".to_string(),
217            wire_encoding: PayloadEncoding::Hex,
218        }
219    }
220}
221
222// ============================================================================
223// Action Configuration
224// ============================================================================
225
/// A single action specification.
///
/// Used as the element type of [`NyxActionSource::Literal`].
#[derive(Clone, Debug)]
pub struct NyxActionSpec {
    /// Optional human-readable name.
    pub name: Option<String>,
    /// Raw payload bytes to send.
    pub payload: Vec<u8>,
}
234
/// Fuzzing mutator types.
///
/// NOTE(review): the mutation logic is not part of this section of the file;
/// the per-variant notes below are inferred from the variant names and should
/// be confirmed against the implementation.
#[derive(Clone, Debug)]
pub enum FuzzMutator {
    /// Presumably flips a single bit of the current input.
    FlipBit,
    /// Presumably inverts or replaces a single byte of the current input.
    FlipByte,
    /// Presumably inserts a byte into the current input.
    InsertByte,
    /// Presumably removes a byte from the current input.
    DeleteByte,
    /// Presumably splices material from one of the configured seeds.
    SpliceSeed,
    /// Presumably resets the current input back to a seed.
    ResetSeed,
    /// Presumably applies several random mutations at once.
    Havoc,
}
246
/// Fuzzing configuration for action generation.
///
/// `NyxVmEnvironment::new` rejects a configuration whose `seeds` or
/// `mutators` list is empty.
#[derive(Clone, Debug)]
pub struct NyxFuzzConfig {
    /// Seed inputs; the first seed becomes the initial fuzz input.
    pub seeds: Vec<Vec<u8>>,
    /// Mutators available to the fuzzer (must be non-empty).
    pub mutators: Vec<FuzzMutator>,
    /// Presumably the minimum length of a generated input — TODO confirm.
    pub min_len: usize,
    /// Presumably the maximum length of a generated input — TODO confirm.
    pub max_len: usize,
    /// Presumably byte strings usable as mutation material — TODO confirm.
    pub dictionary: Vec<Vec<u8>>,
    /// Seed used to fork the fuzzer's random generator.
    pub rng_seed: u64,
}
257
/// Source of actions for the environment.
///
/// `NyxVmEnvironment::new` validates the chosen source: a literal list must
/// contain at least one action, and a fuzz configuration needs at least one
/// seed and one mutator.
#[derive(Clone, Debug)]
pub enum NyxActionSource {
    /// Fixed set of action payloads.
    Literal(Vec<NyxActionSpec>),
    /// Mutation-based action generation.
    Fuzz(NyxFuzzConfig),
}
266
267// ============================================================================
268// Observation Configuration
269// ============================================================================
270
/// How observations are derived from guest output.
///
/// `NyxVmConfig::default()` selects [`NyxObservationPolicy::SharedMemory`].
#[derive(Clone, Copy, Debug)]
pub enum NyxObservationPolicy {
    /// Parse structured OBS/REW/DONE messages from guest.
    FromGuest,
    /// Hash raw output to derive observation.
    OutputHash,
    /// Use raw output bytes as observation stream.
    ///
    /// Requires `observation_stream_len > 0`; `NyxVmEnvironment::new` rejects
    /// the configuration otherwise.
    RawOutput,
    /// Use shared memory contents as observation.
    SharedMemory,
}
283
/// Stream normalization mode.
///
/// NOTE(review): presumably applied relative to
/// `NyxVmConfig::observation_stream_len`, padding with
/// `NyxVmConfig::observation_pad_byte` — the normalization code is not in this
/// section; confirm there.
#[derive(Clone, Copy, Debug)]
pub enum NyxObservationStreamMode {
    /// Pad short streams, truncate long ones.
    PadTruncate,
    /// Only pad short streams.
    Pad,
    /// Only truncate long streams.
    Truncate,
}
294
295// ============================================================================
296// Reward Configuration
297// ============================================================================
298
/// How rewards are computed.
///
/// `Debug` is implemented by hand because the `Custom` variant holds a
/// closure, which cannot derive it.
#[derive(Clone)]
pub enum NyxRewardPolicy {
    /// Parse reward from guest response.
    FromGuest,
    /// Pattern matching on output.
    ///
    /// NOTE(review): presumably `base_reward` applies when `pattern` is absent
    /// from the output and `bonus_reward` when it is present — confirm against
    /// the reward computation, which is outside this section.
    Pattern {
        pattern: String,
        base_reward: i64,
        bonus_reward: i64,
    },
    /// Custom reward function (callback-based).
    ///
    /// The callback receives the step result and returns the reward; it is
    /// shared behind an `Arc` and must be `Send + Sync`.
    Custom(Arc<dyn Fn(&NyxStepResult) -> Reward + Send + Sync>),
}
313
/// Optional reward shaping (additive to base reward).
#[derive(Clone, Debug)]
pub enum NyxRewardShaping {
    /// Entropy reduction vs baseline.
    ///
    /// The baseline entropy of `baseline_bytes` is computed once in
    /// `NyxVmEnvironment::new`: with `max_order == 0` the marginal byte entropy
    /// is used, otherwise the configured stats backend's entropy rate at
    /// `max_order`.
    /// NOTE(review): `scale`, `crash_bonus` and `timeout_bonus` are consumed
    /// outside this section — presumably a multiplier and extra rewards on
    /// crash/timeout exits; confirm.
    EntropyReduction {
        baseline_bytes: Vec<u8>,
        max_order: i64,
        scale: f64,
        crash_bonus: Option<i64>,
        timeout_bonus: Option<i64>,
    },
    /// Entropy of trace data (online learning).
    ///
    /// Requires `NyxVmConfig::trace` to be set; `max_order` is forwarded to the
    /// trace model built in `NyxVmEnvironment::new`.
    /// NOTE(review): `scale` and `normalize` are consumed outside this section.
    TraceEntropy {
        max_order: i64,
        scale: f64,
        normalize: bool,
    },
}
332
333impl std::fmt::Debug for NyxRewardPolicy {
334    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
335        match self {
336            Self::FromGuest => write!(f, "FromGuest"),
337            Self::Pattern {
338                pattern,
339                base_reward,
340                bonus_reward,
341            } => f
342                .debug_struct("Pattern")
343                .field("pattern", pattern)
344                .field("base_reward", base_reward)
345                .field("bonus_reward", bonus_reward)
346                .finish(),
347            Self::Custom(_) => write!(f, "Custom(<fn>)"),
348        }
349    }
350}
351
352// ============================================================================
353// Action Filtering
354// ============================================================================
355
/// Information-theoretic action filtering.
///
/// Every threshold is optional; `None` leaves that criterion unset.
/// NOTE(review): the filtering logic itself is outside this section, so how
/// the thresholds combine is not documented here.
#[derive(Clone, Debug)]
pub struct NyxActionFilter {
    /// Minimum entropy threshold.
    pub min_entropy: Option<f64>,
    /// Maximum entropy threshold.
    pub max_entropy: Option<f64>,
    /// Minimum intrinsic dependence.
    pub min_intrinsic_dependence: Option<f64>,
    /// Minimum novelty (cross-entropy vs prior).
    pub min_novelty: Option<f64>,
    /// Prior corpus for novelty computation.
    pub novelty_prior: Option<Vec<u8>>,
    /// Max order for entropy estimation.
    pub max_order: i64,
    /// Reward to assign when action is rejected.
    pub reject_reward: Option<i64>,
}
374
375// ============================================================================
376// Trace Configuration
377// ============================================================================
378
/// Configuration for trace collection and analysis.
///
/// Must be present in `NyxVmConfig::trace` when trace-entropy reward shaping
/// is selected; `NyxVmEnvironment::new` returns an error otherwise.
#[derive(Clone, Debug)]
pub struct NyxTraceConfig {
    /// Shared memory region name for trace data.
    pub shared_region_name: Option<String>,
    /// Maximum bytes to collect per step.
    pub max_bytes: usize,
    /// Reset trace model on episode boundary.
    pub reset_on_episode: bool,
}
389
390// ============================================================================
391// Main Configuration
392// ============================================================================
393
/// Complete configuration for the nyx-lite VM environment.
///
/// Start from [`Default`] and override what is needed. The default leaves
/// `firecracker_config` empty and the literal action list empty, both of which
/// `NyxVmEnvironment::new` rejects, so at least those must be filled in.
#[derive(Clone)]
pub struct NyxVmConfig {
    /// Path to Firecracker JSON config.
    ///
    /// Relative paths inside that file are resolved against its directory.
    pub firecracker_config: String,
    /// Instance ID for the VM.
    pub instance_id: String,

    // Shared memory configuration
    /// Name of the shared memory region for communication.
    pub shared_region_name: String,
    /// Size of the shared memory region.
    pub shared_region_size: usize,
    /// Shared memory policy (snapshot vs preserve).
    pub shared_memory_policy: SharedMemoryPolicy,

    // Timing configuration
    /// Timeout for each step.
    pub step_timeout: Duration,
    /// Timeout for initial boot.
    pub boot_timeout: Duration,

    // Episode configuration
    /// Number of steps per episode (must be greater than zero).
    pub episode_steps: usize,
    /// Cost subtracted from reward each step.
    pub step_cost: i64,

    // Observation configuration
    /// Observation derivation policy.
    pub observation_policy: NyxObservationPolicy,
    /// Bits per observation symbol.
    pub observation_bits: usize,
    /// Number of observation symbols per action.
    ///
    /// Must be greater than zero when the policy is `RawOutput`.
    pub observation_stream_len: usize,
    /// Stream normalization mode.
    pub observation_stream_mode: NyxObservationStreamMode,
    /// Padding byte for short streams.
    pub observation_pad_byte: u8,

    // Reward configuration
    /// Bits for reward encoding.
    pub reward_bits: usize,
    /// Reward computation policy.
    pub reward_policy: NyxRewardPolicy,
    /// Optional reward shaping (additive; non-canonical).
    pub reward_shaping: Option<NyxRewardShaping>,

    // Action configuration
    /// Source of actions.
    pub action_source: NyxActionSource,
    /// Optional action filter.
    pub action_filter: Option<NyxActionFilter>,

    // Protocol configuration
    /// Wire protocol for structured communication.
    pub protocol: NyxProtocolConfig,

    // Statistics backend
    /// Backend for entropy estimation.
    pub stats_backend: RateBackend,

    // Trace configuration
    /// Optional trace collection.
    pub trace: Option<NyxTraceConfig>,

    // Debug mode
    /// When set, diagnostic messages are printed to stderr (for example while
    /// the VM boots).
    pub debug_mode: bool,

    // Crash logging
    /// Path to log crashes/interesting behaviors (JSONL format).
    pub crash_log: Option<String>,
}
468
469impl Default for NyxVmConfig {
470    fn default() -> Self {
471        Self {
472            firecracker_config: String::new(),
473            instance_id: "aixi-nyx".to_string(),
474            shared_region_name: "shared".to_string(),
475            shared_region_size: 4096,
476            shared_memory_policy: SharedMemoryPolicy::Snapshot,
477            step_timeout: Duration::from_millis(100),
478            boot_timeout: Duration::from_secs(30),
479            episode_steps: 100,
480            step_cost: 0,
481            observation_policy: NyxObservationPolicy::SharedMemory,
482            observation_bits: 8,
483            observation_stream_len: 64,
484            observation_stream_mode: NyxObservationStreamMode::PadTruncate,
485            observation_pad_byte: 0,
486            reward_bits: 8,
487            reward_policy: NyxRewardPolicy::FromGuest,
488            reward_shaping: None,
489            action_source: NyxActionSource::Literal(vec![]),
490            action_filter: None,
491            protocol: NyxProtocolConfig::default(),
492            stats_backend: RateBackend::default(),
493            trace: None,
494            debug_mode: false,
495            crash_log: None,
496        }
497    }
498}
499
500// ============================================================================
501// Step Result
502// ============================================================================
503
/// Result of a single environment step.
///
/// This is also the value handed to a [`NyxRewardPolicy::Custom`] callback.
#[derive(Clone, Debug)]
pub struct NyxStepResult {
    /// Exit reason from the VM.
    pub exit_reason: NyxExitKind,
    /// Raw output data from guest.
    pub output: Vec<u8>,
    /// Parsed observation (if any).
    pub parsed_obs: Option<u64>,
    /// Parsed reward (if any).
    pub parsed_rew: Option<i64>,
    /// Done flag.
    pub done: bool,
    /// Trace data (if collected).
    pub trace_data: Vec<u8>,
    /// Shared memory contents snapshot.
    pub shared_memory: Vec<u8>,
}
522
/// Simplified exit reason categories.
///
/// Produced from nyx-lite's `ExitReason` via `From`; exit reasons without a
/// dedicated variant here are folded into [`NyxExitKind::Other`].
#[derive(Clone, Debug)]
pub enum NyxExitKind {
    /// Guest signalled execution done; carries the code it reported.
    ExecDone(u64),
    /// The run timed out.
    Timeout,
    /// The guest shut down.
    Shutdown,
    /// A hypercall exit; the five values are taken in order from the
    /// underlying `ExitReason::Hypercall` tuple.
    Hypercall {
        code: u64,
        arg1: u64,
        arg2: u64,
        arg3: u64,
        arg4: u64,
    },
    /// Guest debug-print message.
    DebugPrint(String),
    /// Breakpoint exit.
    Breakpoint,
    /// Any other exit reason, described as text (e.g. `"SingleStep"`,
    /// `"SharedMem(<name>)"`, `"HWBreakpoint(<n>)"`).
    Other(String),
}
540
541impl From<ExitReason> for NyxExitKind {
542    fn from(reason: ExitReason) -> Self {
543        match reason {
544            ExitReason::ExecDone(code) => Self::ExecDone(code),
545            ExitReason::Timeout => Self::Timeout,
546            ExitReason::Shutdown => Self::Shutdown,
547            ExitReason::Hypercall(r8, r9, r10, r11, r12) => Self::Hypercall {
548                code: r8,
549                arg1: r9,
550                arg2: r10,
551                arg3: r11,
552                arg4: r12,
553            },
554            ExitReason::DebugPrint(s) => Self::DebugPrint(s),
555            ExitReason::Breakpoint => Self::Breakpoint,
556            ExitReason::RequestSnapshot => Self::Other("RequestSnapshot".to_string()),
557            ExitReason::SharedMem(name, _, _) => Self::Other(format!("SharedMem({})", name)),
558            ExitReason::SingleStep => Self::Other("SingleStep".to_string()),
559            ExitReason::Interrupted => Self::Other("Interrupted".to_string()),
560            ExitReason::HWBreakpoint(n) => Self::Other(format!("HWBreakpoint({})", n)),
561            ExitReason::BadMemoryAccess(_) => Self::Other("BadMemoryAccess".to_string()),
562        }
563    }
564}
565
566// ============================================================================
567// Trace Model
568// ============================================================================
569
/// Predictive model for trace-based reward computation.
///
/// There is one variant per `RateBackend` kind; `TraceModel::new` picks the
/// variant matching the configured backend.
enum TraceModel {
    /// RosaPlus language model; `max_order` is kept so `reset` can rebuild it.
    Rosa {
        model: RosaPlus,
        max_order: i64,
    },
    /// Context-tree-weighting model, updated one bit at a time.
    Ctw {
        tree: crate::ctw::ContextTree,
    },
    /// Factored CTW model; only the low `bits_per_symbol` bits of each byte
    /// are fed to the tree.
    FacCtw {
        tree: crate::ctw::FacContextTree,
        bits_per_symbol: usize,
    },
    /// RWKV-7 model; `primed` records whether the initial prediction has been
    /// computed since the last reset.
    Rwkv7 {
        compressor: Compressor,
        primed: bool,
    },
    /// ZPAQ rate model.
    Zpaq {
        model: ZpaqRateModel,
    },
    /// Mixture predictor; `backend` is kept so `reset` can rebuild `model`.
    Mixture {
        backend: RateBackend,
        model: crate::mixture::RateBackendPredictor,
    },
}
595
impl TraceModel {
    /// Builds a fresh model for `backend`.
    ///
    /// `max_order` is only used by the RosaPlus backend; the other backends
    /// take their parameters from the `RateBackend` value itself.
    fn new(backend: &RateBackend, max_order: i64) -> Self {
        match backend {
            RateBackend::RosaPlus => {
                let mut model = RosaPlus::new(max_order, false, 0, 42);
                model.build_lm_full_bytes_no_finalize_endpos();
                TraceModel::Rosa { model, max_order }
            }
            RateBackend::Rwkv7 { model } => {
                let compressor = Compressor::new_from_model(model.clone());
                TraceModel::Rwkv7 {
                    compressor,
                    // The first prediction is computed lazily in `update_and_score`.
                    primed: false,
                }
            }
            RateBackend::Zpaq { method } => TraceModel::Zpaq {
                model: ZpaqRateModel::new(method.clone(), 2f64.powi(-24)),
            },
            RateBackend::Mixture { spec } => {
                // Keep an owned copy of the backend so `reset` can rebuild the predictor.
                let backend = RateBackend::Mixture { spec: spec.clone() };
                let model = crate::mixture::RateBackendPredictor::from_backend(
                    backend.clone(),
                    -1,
                    2f64.powi(-24),
                );
                TraceModel::Mixture { backend, model }
            }
            RateBackend::Ctw { depth } => TraceModel::Ctw {
                tree: crate::ctw::ContextTree::new(*depth),
            },
            RateBackend::FacCtw {
                base_depth,
                num_percept_bits: _,
                encoding_bits,
            } => {
                // Trace symbols are bytes, so the per-symbol width is kept within 1..=8.
                let bits_per_symbol = (*encoding_bits).min(8).max(1);
                TraceModel::FacCtw {
                    tree: crate::ctw::FacContextTree::new(*base_depth, bits_per_symbol),
                    bits_per_symbol,
                }
            }
        }
    }

    /// Returns the model to its freshly-constructed state.
    ///
    /// Rosa and Mixture models are rebuilt with the parameters used in `new`;
    /// the CTW trees are cleared; the RWKV state is reset and marked unprimed;
    /// the ZPAQ model is reset in place.
    fn reset(&mut self) {
        match self {
            TraceModel::Rosa { model, max_order } => {
                let mut fresh = RosaPlus::new(*max_order, false, 0, 42);
                fresh.build_lm_full_bytes_no_finalize_endpos();
                *model = fresh;
            }
            TraceModel::Ctw { tree } => tree.clear(),
            TraceModel::FacCtw { tree, .. } => tree.clear(),
            TraceModel::Rwkv7 { compressor, primed } => {
                compressor.state.reset();
                *primed = false;
            }
            TraceModel::Zpaq { model } => {
                model.reset();
            }
            TraceModel::Mixture { backend, model } => {
                *model = crate::mixture::RateBackendPredictor::from_backend(
                    backend.clone(),
                    -1,
                    2f64.powi(-24),
                );
            }
        }
    }

    /// Update the model with new data and return the surprise (bits).
    ///
    /// Each byte is scored under the model's current prediction and then
    /// learned, so the result is the total code length of `data` given
    /// everything seen since the last reset. Empty input scores `0.0` and
    /// leaves the model untouched.
    fn update_and_score(&mut self, data: &[u8]) -> f64 {
        if data.is_empty() {
            return 0.0;
        }
        match self {
            TraceModel::Rosa { model, .. } => {
                let mut bits = 0.0;
                let mut tx = model.begin_tx();
                for &b in data {
                    // Floor the probability so a zero never produces an infinite score.
                    let p = model.prob_for_last(b as u32).max(1e-12);
                    bits -= p.log2();
                    model.train_sequence_tx(&mut tx, &[b]);
                }
                bits
            }
            TraceModel::Ctw { tree } => {
                let log_before = tree.get_log_block_probability();
                for &b in data {
                    // Feed each byte most-significant bit first.
                    for i in (0..8).rev() {
                        tree.update(((b >> i) & 1) == 1);
                    }
                }
                let log_after = tree.get_log_block_probability();
                let log_delta = log_after - log_before;
                // Dividing by ln 2 converts to bits; assumes the tree reports
                // natural-log probabilities — TODO confirm.
                -log_delta / std::f64::consts::LN_2
            }
            TraceModel::FacCtw {
                tree,
                bits_per_symbol,
            } => {
                let log_before = tree.get_log_block_probability();
                for &b in data {
                    // Feed the low `bits_per_symbol` bits, least-significant first,
                    // passing the bit position as the second argument.
                    for i in 0..*bits_per_symbol {
                        tree.update(((b >> i) & 1) == 1, i);
                    }
                }
                let log_after = tree.get_log_block_probability();
                let log_delta = log_after - log_before;
                // Same natural-log-to-bits conversion as the plain CTW branch.
                -log_delta / std::f64::consts::LN_2
            }
            TraceModel::Rwkv7 { compressor, primed } => {
                if !*primed {
                    // Run the model once on token 0 to obtain the distribution
                    // for the first byte.
                    let vocab_size = compressor.vocab_size();
                    let logits =
                        compressor
                            .model
                            .forward(&mut compressor.scratch, 0, &mut compressor.state);
                    softmax_pdf_inplace(logits, vocab_size, &mut compressor.pdf_buffer);
                    *primed = true;
                }
                let mut bits = 0.0;
                let vocab_size = compressor.vocab_size();
                for &b in data {
                    // `pdf_buffer` holds the prediction for this byte; score it,
                    // then advance the model to predict the next one.
                    let p = compressor.pdf_buffer[b as usize].max(1e-12);
                    bits -= p.log2();
                    let logits = compressor.model.forward(
                        &mut compressor.scratch,
                        b as u32,
                        &mut compressor.state,
                    );
                    softmax_pdf_inplace(logits, vocab_size, &mut compressor.pdf_buffer);
                }
                bits
            }
            TraceModel::Zpaq { model } => model.update_and_score(data),
            TraceModel::Mixture { model, .. } => {
                let mut bits = 0.0;
                for &b in data {
                    let logp = model.log_prob(b);
                    // Assumes `log_prob` is a natural log — TODO confirm.
                    bits -= logp / std::f64::consts::LN_2;
                    model.update(b);
                }
                bits
            }
        }
    }
}
744
745// ============================================================================
746// Fuzz State
747// ============================================================================
748
/// Mutable state of the fuzzing action source.
struct FuzzState {
    /// Current input bytes; initialised from the first configured seed.
    current: Vec<u8>,
    /// Random generator forked from the configured `rng_seed`.
    rng: RandomGenerator,
}
753
754// ============================================================================
755// NyxVmEnvironment
756// ============================================================================
757
/// High-performance VM environment using nyx-lite.
///
/// Construct it with `NyxVmEnvironment::new`, which validates the
/// configuration, creates the VM and runs the boot/initialization sequence.
pub struct NyxVmEnvironment {
    /// Configuration.
    config: NyxVmConfig,
    /// The nyx-lite VM instance.
    vm: NyxVM,
    /// Base snapshot for episode resets.
    base_snapshot: Option<Arc<NyxSnapshot>>,
    /// Shared memory virtual address in guest.
    shared_vaddr: Option<u64>,
    /// CR3 used when shared memory was registered.
    shared_cr3: Option<u64>,
    /// Trace model for entropy-based rewards.
    ///
    /// Only present when trace-entropy reward shaping is configured.
    trace_model: Option<TraceModel>,
    /// Baseline entropy for entropy reduction rewards.
    ///
    /// Only present when entropy-reduction reward shaping is configured.
    baseline_entropy: Option<f64>,
    /// Effective reward shaping policy (additive).
    reward_shaping: Option<NyxRewardShaping>,
    /// Fuzzing state.
    ///
    /// Only present when the action source is `NyxActionSource::Fuzz`.
    fuzz_state: Option<FuzzState>,

    // Current step state
    /// Current observation.
    obs: PerceptVal,
    /// Current reward.
    rew: Reward,
    /// Current observation stream.
    obs_stream: Vec<PerceptVal>,
    /// Step within current episode.
    step_in_episode: usize,
    /// Whether the environment needs reset (starts out `true`).
    needs_reset: bool,
    /// Whether the VM has been initialized.
    initialized: bool,
}
793
794impl NyxVmEnvironment {
795    /// Creates a new NyxVmEnvironment with the given configuration.
796    pub fn new(config: NyxVmConfig) -> anyhow::Result<Self> {
797        // Validate configuration
798        if config.firecracker_config.is_empty() {
799            return Err(anyhow::anyhow!("firecracker_config path must be set"));
800        }
801        if config.episode_steps == 0 {
802            return Err(anyhow::anyhow!("episode_steps must be > 0"));
803        }
804        if matches!(config.observation_policy, NyxObservationPolicy::RawOutput)
805            && config.observation_stream_len == 0
806        {
807            return Err(anyhow::anyhow!(
808                "observation_stream_len must be > 0 for RawOutput policy"
809            ));
810        }
811
812        // Load Firecracker config and resolve relative paths
813        let fc_config_raw = std::fs::read_to_string(&config.firecracker_config)
814            .map_err(|e| anyhow::anyhow!("Failed to read firecracker config: {}", e))?;
815        let fc_config =
816            rewrite_firecracker_config_paths(&config.firecracker_config, &fc_config_raw)
817                .map_err(|e| anyhow::anyhow!("Failed to parse firecracker config: {}", e))?;
818
819        // Create the VM
820        let vm = NyxVM::new(config.instance_id.clone(), &fc_config);
821
822        // Initialize reward shaping
823        let reward_shaping = config.reward_shaping.clone();
824
825        if matches!(reward_shaping, Some(NyxRewardShaping::TraceEntropy { .. }))
826            && config.trace.is_none()
827        {
828            return Err(anyhow::anyhow!(
829                "vm_trace must be configured for vm_reward_shaping.mode=trace-entropy"
830            ));
831        }
832
833        // Initialize trace model if needed
834        let trace_model = match &reward_shaping {
835            Some(NyxRewardShaping::TraceEntropy { max_order, .. }) => {
836                Some(TraceModel::new(&config.stats_backend, *max_order))
837            }
838            _ => None,
839        };
840
841        // Compute baseline entropy if needed
842        let baseline_entropy = match &reward_shaping {
843            Some(NyxRewardShaping::EntropyReduction {
844                baseline_bytes,
845                max_order,
846                ..
847            }) => {
848                let h = if *max_order == 0 {
849                    marginal_entropy_bytes(baseline_bytes)
850                } else {
851                    entropy_rate_backend(baseline_bytes, *max_order, &config.stats_backend)
852                };
853                Some(h)
854            }
855            _ => None,
856        };
857
858        // Initialize fuzz state if needed
859        let fuzz_state = match &config.action_source {
860            NyxActionSource::Fuzz(fuzz) => {
861                if fuzz.seeds.is_empty() {
862                    return Err(anyhow::anyhow!("Fuzz mode requires at least one seed"));
863                }
864                if fuzz.mutators.is_empty() {
865                    return Err(anyhow::anyhow!("Fuzz mode requires at least one mutator"));
866                }
867                let seed = fuzz.seeds[0].clone();
868                Some(FuzzState {
869                    current: seed,
870                    rng: RandomGenerator::new().fork_with(fuzz.rng_seed),
871                })
872            }
873            NyxActionSource::Literal(actions) => {
874                if actions.is_empty() {
875                    return Err(anyhow::anyhow!("Literal mode requires at least one action"));
876                }
877                None
878            }
879        };
880
881        let mut env = Self {
882            config,
883            vm,
884            base_snapshot: None,
885            shared_vaddr: None,
886            shared_cr3: None,
887            trace_model,
888            baseline_entropy,
889            reward_shaping,
890            fuzz_state,
891            obs: 0,
892            rew: 0,
893            obs_stream: Vec::new(),
894            step_in_episode: 0,
895            needs_reset: true,
896            initialized: false,
897        };
898
899        // Boot and initialize
900        env.initialize()?;
901
902        Ok(env)
903    }
904
905    /// Initializes the VM by booting to the snapshot point.
906    fn initialize(&mut self) -> anyhow::Result<()> {
907        if self.initialized {
908            return Ok(());
909        }
910
911        if self.config.debug_mode {
912            eprintln!("[NyxVm] Booting VM...");
913        }
914
915        // Run until we get the shared memory registration
916        let start = Instant::now();
917        loop {
918            if start.elapsed() > self.config.boot_timeout {
919                return Err(anyhow::anyhow!("Boot timeout waiting for shared memory"));
920            }
921
922            let exit = self.vm.run(Duration::from_secs(1));
923            match exit {
924                ExitReason::SharedMem(name, vaddr, size) => {
925                    if self.config.debug_mode {
926                        eprintln!(
927                            "[NyxVm] Shared memory registered: {} @ {:#x} ({} bytes)",
928                            name, vaddr, size
929                        );
930                    }
931                    if name.trim_end_matches('\0') == self.config.shared_region_name {
932                        self.shared_vaddr = Some(vaddr);
933                        self.shared_cr3 = Some(self.vm.sregs().cr3);
934                        // Register the shared region with the configured policy
935                        let _ = self.vm.register_shared_region_current(
936                            vaddr,
937                            size,
938                            self.config.shared_memory_policy,
939                        );
940                        break;
941                    }
942                }
943                ExitReason::DebugPrint(msg) => {
944                    if self.config.debug_mode {
945                        eprintln!("[NyxVm] Guest: {}", msg);
946                    }
947                }
948                ExitReason::Shutdown => {
949                    return Err(anyhow::anyhow!("VM shut down during boot"));
950                }
951                _ => {
952                    if self.config.debug_mode {
953                        eprintln!("[NyxVm] Boot exit: {:?}", exit);
954                    }
955                    // Continue waiting
956                }
957            }
958        }
959
960        // Continue running until snapshot request
961        loop {
962            if start.elapsed() > self.config.boot_timeout {
963                return Err(anyhow::anyhow!("Boot timeout waiting for snapshot request"));
964            }
965
966            let exit = self.vm.run(Duration::from_secs(1));
967            match exit {
968                ExitReason::RequestSnapshot => {
969                    if self.config.debug_mode {
970                        eprintln!("[NyxVm] Taking base snapshot...");
971                    }
972                    self.base_snapshot = Some(self.vm.take_base_snapshot());
973                    break;
974                }
975                ExitReason::DebugPrint(msg) => {
976                    if self.config.debug_mode {
977                        eprintln!("[NyxVm] Guest: {}", msg);
978                    }
979                }
980                ExitReason::Shutdown => {
981                    return Err(anyhow::anyhow!("VM shut down before snapshot"));
982                }
983                _ => {
984                    if self.config.debug_mode {
985                        eprintln!("[NyxVm] Snapshot wait exit: {:?}", exit);
986                    }
987                    // Continue waiting
988                }
989            }
990        }
991
992        if self.config.debug_mode {
993            eprintln!("[NyxVm] Initialization complete");
994        }
995
996        self.initialized = true;
997        self.needs_reset = false;
998        Ok(())
999    }
1000
1001    /// Resets to the base snapshot.
1002    pub fn reset(&mut self) -> anyhow::Result<()> {
1003        let snapshot = self
1004            .base_snapshot
1005            .as_ref()
1006            .ok_or_else(|| anyhow::anyhow!("No base snapshot available"))?
1007            .clone();
1008
1009        self.vm.apply_snapshot(&snapshot);
1010
1011        // Reset trace model if configured
1012        if let Some(trace_cfg) = &self.config.trace {
1013            if trace_cfg.reset_on_episode {
1014                if let Some(model) = &mut self.trace_model {
1015                    model.reset();
1016                }
1017            }
1018        }
1019
1020        self.step_in_episode = 0;
1021        self.needs_reset = false;
1022
1023        Ok(())
1024    }
1025
1026    /// Writes action data to shared memory.
1027    fn write_action_to_shared_memory(&mut self, payload: &[u8]) -> anyhow::Result<()> {
1028        let vaddr = self
1029            .shared_vaddr
1030            .ok_or_else(|| anyhow::anyhow!("Shared memory not initialized"))?;
1031        let cr3 = self
1032            .shared_cr3
1033            .ok_or_else(|| anyhow::anyhow!("Shared memory CR3 not initialized"))?;
1034        let process = self.vm.process_memory(cr3);
1035
1036        // Ensure guest has cleared the previous message length to avoid races.
1037        let wait_start = Instant::now();
1038        loop {
1039            let cur_len = process
1040                .read_u64(vaddr + SHARED_ACTION_LEN_OFFSET)
1041                .unwrap_or(0);
1042            if cur_len == 0 {
1043                break;
1044            }
1045            if wait_start.elapsed() > self.config.step_timeout {
1046                return Err(anyhow::anyhow!("shared buffer busy (len={cur_len})"));
1047            }
1048            std::thread::yield_now();
1049        }
1050
1051        // Write length as first 8 bytes (u64 LE)
1052        let len = payload.len() as u64;
1053        process
1054            .write_u64(vaddr + SHARED_ACTION_LEN_OFFSET, len)
1055            .map_err(|e| anyhow::anyhow!("write len failed: {e}"))?;
1056        let _ = process.write_u64(vaddr + SHARED_RESP_LEN_OFFSET, 0);
1057
1058        // Write payload starting at offset 8
1059        let max_len = self
1060            .config
1061            .shared_region_size
1062            .saturating_sub(SHARED_PAYLOAD_OFFSET as usize);
1063        let write_len = payload.len().min(max_len);
1064        if write_len > 0 {
1065            let _ = process
1066                .write_bytes(vaddr + SHARED_PAYLOAD_OFFSET, &payload[..write_len])
1067                .map_err(|e| anyhow::anyhow!("write payload failed: {e}"))?;
1068        }
1069
1070        if self.config.debug_mode {
1071            let verify = process
1072                .read_u64(vaddr + SHARED_ACTION_LEN_OFFSET)
1073                .unwrap_or(0) as usize;
1074            eprintln!(
1075                "[NyxVm] Wrote action len={}, verified len={}",
1076                write_len, verify
1077            );
1078        }
1079
1080        Ok(())
1081    }
1082
1083    /// Reads response from shared memory.
1084    fn read_shared_memory(&self) -> Vec<u8> {
1085        let Some(vaddr) = self.shared_vaddr else {
1086            return Vec::new();
1087        };
1088        let Some(cr3) = self.shared_cr3 else {
1089            return Vec::new();
1090        };
1091        let process = self.vm.process_memory(cr3);
1092
1093        // Read length from first 8 bytes
1094        let len = process
1095            .read_u64(vaddr + SHARED_RESP_LEN_OFFSET)
1096            .unwrap_or(0) as usize;
1097        let max_len = self
1098            .config
1099            .shared_region_size
1100            .saturating_sub(SHARED_PAYLOAD_OFFSET as usize);
1101        let read_len = len.min(max_len);
1102
1103        if read_len == 0 {
1104            return Vec::new();
1105        }
1106
1107        let mut buf = vec![0u8; read_len];
1108        let _ = process.read_bytes(vaddr + SHARED_PAYLOAD_OFFSET, &mut buf);
1109        buf
1110    }
1111
1112    fn clear_shared_length(&self) {
1113        let (Some(vaddr), Some(cr3)) = (self.shared_vaddr, self.shared_cr3) else {
1114            return;
1115        };
1116        let process = self.vm.process_memory(cr3);
1117        let _ = process.write_u64(vaddr + SHARED_ACTION_LEN_OFFSET, 0);
1118        let _ = process.write_u64(vaddr + SHARED_RESP_LEN_OFFSET, 0);
1119    }
1120
1121    /// Runs a single step, returning detailed results.
1122    pub fn run_step(&mut self, payload: &[u8]) -> anyhow::Result<NyxStepResult> {
1123        // Write action to shared memory
1124        self.write_action_to_shared_memory(payload)?;
1125
1126        // Run the VM until we get a meaningful exit
1127        let start = Instant::now();
1128        let mut output = Vec::new();
1129        let mut trace_data = Vec::new();
1130        let mut parsed_obs = None;
1131        let mut parsed_rew = None;
1132        let mut done = false;
1133        let exit_kind;
1134        let collect_output =
1135            matches!(
1136                self.config.observation_policy,
1137                NyxObservationPolicy::OutputHash | NyxObservationPolicy::RawOutput
1138            ) || matches!(self.config.reward_policy, NyxRewardPolicy::Pattern { .. })
1139                || matches!(
1140                    self.reward_shaping,
1141                    Some(NyxRewardShaping::EntropyReduction { .. })
1142                );
1143
1144        loop {
1145            let remaining = self
1146                .config
1147                .step_timeout
1148                .checked_sub(start.elapsed())
1149                .unwrap_or(Duration::ZERO);
1150
1151            if remaining.is_zero() {
1152                exit_kind = NyxExitKind::Timeout;
1153                break;
1154            }
1155
1156            let exit = self.vm.run(remaining);
1157            match exit {
1158                ExitReason::ExecDone(code) => {
1159                    exit_kind = NyxExitKind::ExecDone(code);
1160                    done = true;
1161                    break;
1162                }
1163                ExitReason::Timeout => {
1164                    if self.config.debug_mode {
1165                        eprintln!("[NyxVm] Step timeout");
1166                    }
1167                    exit_kind = NyxExitKind::Timeout;
1168                    break;
1169                }
1170                ExitReason::Shutdown => {
1171                    if self.config.debug_mode {
1172                        eprintln!("[NyxVm] VM shutdown during step");
1173                    }
1174                    exit_kind = NyxExitKind::Shutdown;
1175                    done = true;
1176                    break;
1177                }
1178                ExitReason::DebugPrint(msg) => {
1179                    if self.config.debug_mode {
1180                        eprintln!("[NyxVm] Guest: {}", msg);
1181                    }
1182                    // Accumulate debug output
1183                    if collect_output {
1184                        output.extend_from_slice(msg.as_bytes());
1185                    }
1186                    // Continue running
1187                }
1188                ExitReason::Hypercall(r8, r9, r10, r11, r12) => {
1189                    exit_kind = NyxExitKind::Hypercall {
1190                        code: r8,
1191                        arg1: r9,
1192                        arg2: r10,
1193                        arg3: r11,
1194                        arg4: r12,
1195                    };
1196                    // Attempt to parse structured response
1197                    if let Some(obs) = Self::try_parse_u64(r9) {
1198                        parsed_obs = Some(obs);
1199                    }
1200                    if let Some(rew) = Self::try_parse_i64(r10) {
1201                        parsed_rew = Some(rew);
1202                    }
1203                    break;
1204                }
1205                ExitReason::Breakpoint => {
1206                    if self.config.debug_mode {
1207                        eprintln!("[NyxVm] Breakpoint exit during step");
1208                    }
1209                    exit_kind = NyxExitKind::Breakpoint;
1210                    break;
1211                }
1212                _ => {
1213                    // Continue for other exits
1214                }
1215            }
1216        }
1217
1218        // Read shared memory contents (only if needed)
1219        let need_shared_memory = matches!(
1220            self.config.observation_policy,
1221            NyxObservationPolicy::SharedMemory
1222        ) || matches!(
1223            self.config.reward_policy,
1224            NyxRewardPolicy::Pattern { .. }
1225        ) || matches!(
1226            self.reward_shaping,
1227            Some(NyxRewardShaping::EntropyReduction { .. })
1228        ) || self.config.trace.is_some();
1229        let shared_memory = if need_shared_memory {
1230            self.read_shared_memory()
1231        } else {
1232            Vec::new()
1233        };
1234
1235        // Clear shared length to avoid host/guest races on the next step.
1236        self.clear_shared_length();
1237
1238        // Collect trace data if configured
1239        if let Some(trace_cfg) = &self.config.trace {
1240            if trace_cfg.shared_region_name.is_some() {
1241                // Read from trace shared memory region (implementation-specific)
1242                // For now, use main shared memory as fallback
1243                trace_data = shared_memory.clone();
1244                if trace_data.len() > trace_cfg.max_bytes {
1245                    trace_data.truncate(trace_cfg.max_bytes);
1246                }
1247            }
1248        }
1249
1250        Ok(NyxStepResult {
1251            exit_reason: exit_kind,
1252            output,
1253            parsed_obs,
1254            parsed_rew,
1255            done,
1256            trace_data,
1257            shared_memory,
1258        })
1259    }
1260
1261    fn try_parse_u64(val: u64) -> Option<u64> {
1262        // Hypercall args are already u64
1263        Some(val)
1264    }
1265
1266    fn try_parse_i64(val: u64) -> Option<i64> {
1267        Some(val as i64)
1268    }
1269
1270    /// Gets the action payload for the given action index.
1271    fn get_action_payload(&mut self, action: Action) -> anyhow::Result<Cow<'_, [u8]>> {
1272        match &self.config.action_source {
1273            NyxActionSource::Literal(actions) => {
1274                let idx = action as usize;
1275                if idx >= actions.len() {
1276                    return Err(anyhow::anyhow!("Action index out of range"));
1277                }
1278                Ok(Cow::Borrowed(actions[idx].payload.as_slice()))
1279            }
1280            NyxActionSource::Fuzz(fuzz) => {
1281                let state = self
1282                    .fuzz_state
1283                    .as_mut()
1284                    .ok_or_else(|| anyhow::anyhow!("Fuzz state missing"))?;
1285                let idx = action as usize % fuzz.mutators.len();
1286                let mut input = state.current.clone();
1287                let mutator = &fuzz.mutators[idx];
1288                apply_mutator(mutator, &mut input, fuzz, &mut state.rng);
1289                if input.len() < fuzz.min_len {
1290                    input.resize(fuzz.min_len, 0);
1291                }
1292                if input.len() > fuzz.max_len {
1293                    input.truncate(fuzz.max_len);
1294                }
1295                state.current = input.clone();
1296                Ok(Cow::Owned(input))
1297            }
1298        }
1299    }
1300
1301    /// Applies action filtering, returning reject reward if filtered.
1302    fn filter_action(&self, payload: &[u8]) -> Option<i64> {
1303        let filter = self.config.action_filter.as_ref()?;
1304        if payload.is_empty() {
1305            return filter.reject_reward;
1306        }
1307
1308        let (entropy, intrinsic, novelty) = self.compute_filter_metrics(payload, filter);
1309
1310        if let Some(min_entropy) = filter.min_entropy {
1311            if entropy < min_entropy {
1312                return filter.reject_reward;
1313            }
1314        }
1315        if let Some(max_entropy) = filter.max_entropy {
1316            if entropy > max_entropy {
1317                return filter.reject_reward;
1318            }
1319        }
1320        if let Some(min_intrinsic) = filter.min_intrinsic_dependence {
1321            if intrinsic < min_intrinsic {
1322                return filter.reject_reward;
1323            }
1324        }
1325        if let Some(min_novelty) = filter.min_novelty {
1326            if filter.novelty_prior.is_some() && novelty < min_novelty {
1327                return filter.reject_reward;
1328            }
1329        }
1330        None
1331    }
1332
1333    fn wrap_action_payload(&self, payload: &[u8]) -> Vec<u8> {
1334        let p = &self.config.protocol;
1335        let mut wrapped = p.action_prefix.clone().into_bytes();
1336        wrapped.extend_from_slice(p.wire_encoding.encode(payload).as_bytes());
1337        wrapped.extend_from_slice(p.action_suffix.as_bytes());
1338        wrapped
1339    }
1340
1341    fn compute_filter_metrics(&self, payload: &[u8], filter: &NyxActionFilter) -> (f64, f64, f64) {
1342        let h_marg = marginal_entropy_bytes(payload);
1343        let h_rate = if filter.max_order == 0 {
1344            h_marg
1345        } else {
1346            entropy_rate_backend(payload, filter.max_order, &self.config.stats_backend)
1347        };
1348
1349        let intrinsic = if h_marg < 1e-9 {
1350            0.0
1351        } else {
1352            ((h_marg - h_rate) / h_marg).clamp(0.0, 1.0)
1353        };
1354
1355        let novelty = if let Some(ref prior) = filter.novelty_prior {
1356            cross_entropy_rate_backend(payload, prior, filter.max_order, &self.config.stats_backend)
1357        } else {
1358            0.0
1359        };
1360
1361        (h_rate, intrinsic, novelty)
1362    }
1363
1364    /// Computes reward from step result.
1365    fn compute_reward(&mut self, result: &NyxStepResult) -> Reward {
1366        let base_reward = match &self.config.reward_policy {
1367            NyxRewardPolicy::FromGuest => result.parsed_rew.unwrap_or(0),
1368            NyxRewardPolicy::Pattern {
1369                pattern,
1370                base_reward,
1371                bonus_reward,
1372            } => {
1373                let text = String::from_utf8_lossy(&result.output);
1374                let shared_text = String::from_utf8_lossy(&result.shared_memory);
1375                if text.contains(pattern) || shared_text.contains(pattern) {
1376                    base_reward + bonus_reward
1377                } else {
1378                    *base_reward
1379                }
1380            }
1381            NyxRewardPolicy::Custom(f) => f(result),
1382        };
1383
1384        let shaping_reward = if let Some(shaping) = self.reward_shaping.clone() {
1385            self.compute_reward_shaping(&shaping, result)
1386        } else {
1387            0
1388        };
1389
1390        let mut reward = base_reward.saturating_add(shaping_reward);
1391
1392        reward = reward.saturating_sub(self.config.step_cost);
1393        let min_reward = self.min_reward();
1394        let max_reward = self.max_reward();
1395        reward.clamp(min_reward, max_reward)
1396    }
1397
1398    fn compute_reward_shaping(
1399        &mut self,
1400        shaping: &NyxRewardShaping,
1401        result: &NyxStepResult,
1402    ) -> Reward {
1403        match shaping {
1404            NyxRewardShaping::EntropyReduction {
1405                max_order,
1406                scale,
1407                crash_bonus,
1408                timeout_bonus,
1409                ..
1410            } => {
1411                let mut base_reward = {
1412                    let data = if result.shared_memory.is_empty() {
1413                        &result.output
1414                    } else {
1415                        &result.shared_memory
1416                    };
1417                    let h_obs = if *max_order == 0 {
1418                        marginal_entropy_bytes(data)
1419                    } else {
1420                        entropy_rate_backend(data, *max_order, &self.config.stats_backend)
1421                    };
1422                    let h_base = self.baseline_entropy.unwrap_or(0.0);
1423                    let er = (h_base - h_obs) * scale;
1424                    er.round() as i64
1425                };
1426
1427                // Add bonuses for interesting behaviors (bugs/crashes)
1428                match &result.exit_reason {
1429                    NyxExitKind::Shutdown | NyxExitKind::Breakpoint => {
1430                        if let Some(bonus) = crash_bonus {
1431                            base_reward = base_reward.saturating_add(*bonus);
1432                        }
1433                    }
1434                    NyxExitKind::Timeout => {
1435                        if let Some(bonus) = timeout_bonus {
1436                            base_reward = base_reward.saturating_add(*bonus);
1437                        }
1438                    }
1439                    _ => {}
1440                }
1441
1442                base_reward
1443            }
1444            NyxRewardShaping::TraceEntropy {
1445                scale, normalize, ..
1446            } => {
1447                let data = &result.trace_data;
1448                let bits = match self.trace_model.as_mut() {
1449                    Some(model) => model.update_and_score(data),
1450                    None => 0.0,
1451                };
1452                let bits = if *normalize && !data.is_empty() {
1453                    bits / data.len() as f64
1454                } else {
1455                    bits
1456                };
1457                (bits * scale).round() as i64
1458            }
1459        }
1460    }
1461
1462    fn mask_observation(&self, value: u64) -> u64 {
1463        let bits = self.config.observation_bits;
1464        if bits >= 64 {
1465            value
1466        } else if bits == 0 {
1467            0
1468        } else {
1469            value & ((1u64 << bits) - 1)
1470        }
1471    }
1472
1473    fn build_observation_stream(&self, result: &NyxStepResult) -> Vec<PerceptVal> {
1474        let mut observations = match self.config.observation_policy {
1475            NyxObservationPolicy::FromGuest => {
1476                if let Some(obs) = result.parsed_obs {
1477                    vec![self.mask_observation(obs)]
1478                } else {
1479                    vec![self.hash_observation(&result.shared_memory)]
1480                }
1481            }
1482            NyxObservationPolicy::OutputHash => {
1483                vec![self.hash_observation(&result.output)]
1484            }
1485            NyxObservationPolicy::RawOutput => {
1486                result.output.iter().map(|b| *b as PerceptVal).collect()
1487            }
1488            NyxObservationPolicy::SharedMemory => result
1489                .shared_memory
1490                .iter()
1491                .map(|b| *b as PerceptVal)
1492                .collect(),
1493        };
1494
1495        if observations.is_empty() {
1496            observations.push(0);
1497        }
1498
1499        self.normalize_observation_stream(&mut observations);
1500        observations
1501    }
1502
1503    fn hash_observation(&self, data: &[u8]) -> PerceptVal {
1504        let h = robust_hash_bytes(data);
1505        self.mask_observation(h)
1506    }
1507
1508    fn normalize_observation_stream(&self, observations: &mut Vec<PerceptVal>) {
1509        let mask = if self.config.observation_bits >= 64 {
1510            u64::MAX
1511        } else if self.config.observation_bits == 0 {
1512            0
1513        } else {
1514            (1u64 << self.config.observation_bits) - 1
1515        };
1516
1517        for obs in observations.iter_mut() {
1518            *obs &= mask;
1519        }
1520
1521        let target = self.config.observation_stream_len;
1522        if target == 0 {
1523            return;
1524        }
1525
1526        if observations.len() > target {
1527            match self.config.observation_stream_mode {
1528                NyxObservationStreamMode::Truncate | NyxObservationStreamMode::PadTruncate => {
1529                    observations.truncate(target);
1530                }
1531                NyxObservationStreamMode::Pad => {}
1532            }
1533        } else if observations.len() < target {
1534            match self.config.observation_stream_mode {
1535                NyxObservationStreamMode::Pad | NyxObservationStreamMode::PadTruncate => {
1536                    let pad = self.config.observation_pad_byte as PerceptVal;
1537                    observations.resize(target, pad);
1538                }
1539                NyxObservationStreamMode::Truncate => {}
1540            }
1541        }
1542    }
1543
1544    fn action_count(&self) -> usize {
1545        match &self.config.action_source {
1546            NyxActionSource::Literal(actions) => actions.len(),
1547            NyxActionSource::Fuzz(fuzz) => fuzz.mutators.len(),
1548        }
1549    }
1550
    /// Direct access to the underlying NyxVM for advanced use cases.
    ///
    /// Returns a shared reference to the VM handle held by this environment.
    pub fn vm(&self) -> &NyxVM {
        &self.vm
    }
1555
    /// Mutable access to the underlying NyxVM.
    ///
    /// Changes made through this reference bypass the environment's own bookkeeping
    /// (step counter, pending-reset flag); this accessor updates none of it.
    pub fn vm_mut(&mut self) -> &mut NyxVM {
        &mut self.vm
    }
1560
    /// Takes a new snapshot at the current state.
    ///
    /// Delegates to the VM's `take_snapshot` and hands the result to the caller. This
    /// method does not store the snapshot in the environment, so the base snapshot
    /// used by `reset` is not replaced by it.
    pub fn take_snapshot(&mut self) -> Arc<NyxSnapshot> {
        self.vm.take_snapshot()
    }
1565
    /// Applies a specific snapshot.
    ///
    /// Delegates to the VM's `apply_snapshot`. Unlike `reset`, this does not touch the
    /// in-episode step counter, the pending-reset flag, or the trace model.
    pub fn apply_snapshot(&mut self, snapshot: &Arc<NyxSnapshot>) {
        self.vm.apply_snapshot(snapshot);
    }
1570
1571    /// Resets trace model.
1572    pub fn reset_trace_model(&mut self) {
1573        if let Some(model) = &mut self.trace_model {
1574            model.reset();
1575        }
1576    }
1577
1578    /// Logs crashes and interesting behaviors to file.
1579    fn log_crash(&self, action_payload: &[u8], result: &NyxStepResult, reward: i64) {
1580        let Some(log_path) = &self.config.crash_log else {
1581            return;
1582        };
1583
1584        // Only log interesting exits
1585        let is_interesting = matches!(
1586            result.exit_reason,
1587            NyxExitKind::Shutdown | NyxExitKind::Breakpoint | NyxExitKind::Timeout
1588        );
1589
1590        if !is_interesting {
1591            return;
1592        }
1593
1594        let log_entry = serde_json::json!({
1595            "timestamp": std::time::SystemTime::now()
1596                .duration_since(std::time::UNIX_EPOCH)
1597                .unwrap_or_default()
1598                .as_secs(),
1599            "exit_reason": format!("{:?}", result.exit_reason),
1600            "action_payload": hex_encode(action_payload),
1601            "action_payload_str": String::from_utf8_lossy(action_payload),
1602            "output": String::from_utf8_lossy(&result.output),
1603            "shared_memory": hex_encode(&result.shared_memory),
1604            "reward": reward,
1605            "parsed_obs": result.parsed_obs,
1606            "parsed_rew": result.parsed_rew,
1607        });
1608
1609        // Append to JSONL file
1610        if let Ok(mut file) = OpenOptions::new().create(true).append(true).open(log_path) {
1611            if let Ok(json_str) = serde_json::to_string(&log_entry) {
1612                let _ = writeln!(file, "{}", json_str);
1613            }
1614        }
1615    }
1616}
1617
1618// ============================================================================
1619// Environment Trait Implementation
1620// ============================================================================
1621
1622impl Environment for NyxVmEnvironment {
1623    fn perform_action(&mut self, action: Action) {
1624        if self.needs_reset {
1625            if let Err(e) = self.reset() {
1626                if self.config.debug_mode {
1627                    eprintln!("[NyxVm] Reset failed: {}", e);
1628                }
1629            }
1630        }
1631
1632        let payload = match self.get_action_payload(action) {
1633            Ok(payload) => payload.into_owned(),
1634            Err(e) => {
1635                if self.config.debug_mode {
1636                    eprintln!("[NyxVm] Invalid action: {}", e);
1637                }
1638                self.obs = 0;
1639                self.rew = self.min_reward();
1640                self.obs_stream.clear();
1641                self.obs_stream.push(0);
1642                self.step_in_episode = (self.step_in_episode + 1) % self.config.episode_steps;
1643                if self.step_in_episode == 0 {
1644                    self.needs_reset = true;
1645                }
1646                return;
1647            }
1648        };
1649
1650        // Check action filter
1651        if let Some(reject_reward) = self.filter_action(&payload) {
1652            self.obs = 0;
1653            self.rew = reject_reward.clamp(self.min_reward(), self.max_reward());
1654            self.obs_stream.clear();
1655            self.obs_stream.push(0);
1656            self.step_in_episode = (self.step_in_episode + 1) % self.config.episode_steps;
1657            if self.step_in_episode == 0 {
1658                self.needs_reset = true;
1659            }
1660            return;
1661        }
1662
1663        // Run the step
1664        let wrapped_payload = self.wrap_action_payload(&payload);
1665        let result = match self.run_step(&wrapped_payload) {
1666            Ok(result) => result,
1667            Err(e) => {
1668                if self.config.debug_mode {
1669                    eprintln!("[NyxVm] Step failed: {}", e);
1670                }
1671                self.obs = 0;
1672                self.rew = self.min_reward();
1673                self.obs_stream.clear();
1674                self.obs_stream.push(0);
1675                self.step_in_episode = (self.step_in_episode + 1) % self.config.episode_steps;
1676                if self.step_in_episode == 0 {
1677                    self.needs_reset = true;
1678                }
1679                return;
1680            }
1681        };
1682
1683        // Process results
1684        self.obs_stream = self.build_observation_stream(&result);
1685        self.obs = self.obs_stream.first().copied().unwrap_or(0);
1686        self.rew = self.compute_reward(&result);
1687
1688        // Log crashes and interesting behaviors
1689        self.log_crash(&payload, &result, self.rew);
1690
1691        if self.config.debug_mode {
1692            eprintln!(
1693                "[NyxVm] Action={} Obs={} Rew={} Done={:?} Exit={:?}",
1694                action, self.obs, self.rew, result.done, result.exit_reason
1695            );
1696        }
1697
1698        self.step_in_episode = (self.step_in_episode + 1) % self.config.episode_steps;
1699        if self.step_in_episode == 0 || result.done {
1700            self.needs_reset = true;
1701        }
1702    }
1703
1704    fn get_observation(&self) -> PerceptVal {
1705        self.obs
1706    }
1707
1708    fn drain_observations(&mut self) -> Vec<PerceptVal> {
1709        if self.obs_stream.is_empty() {
1710            vec![self.obs]
1711        } else {
1712            std::mem::take(&mut self.obs_stream)
1713        }
1714    }
1715
1716    fn get_reward(&self) -> Reward {
1717        self.rew
1718    }
1719
1720    fn is_finished(&self) -> bool {
1721        false
1722    }
1723
1724    fn get_observation_bits(&self) -> usize {
1725        self.config.observation_bits
1726    }
1727
1728    fn get_reward_bits(&self) -> usize {
1729        self.config.reward_bits
1730    }
1731
1732    fn get_action_bits(&self) -> usize {
1733        let n = self.action_count();
1734        if n <= 1 {
1735            return 1;
1736        }
1737        (n as f64).log2().ceil() as usize
1738    }
1739
1740    fn get_num_actions(&self) -> usize {
1741        self.action_count()
1742    }
1743
1744    fn max_reward(&self) -> Reward {
1745        let bits = self.config.reward_bits;
1746        if bits >= 64 {
1747            i64::MAX
1748        } else if bits == 0 {
1749            0
1750        } else {
1751            (1i64 << (bits - 1)) - 1
1752        }
1753    }
1754
1755    fn min_reward(&self) -> Reward {
1756        let bits = self.config.reward_bits;
1757        if bits >= 64 {
1758            i64::MIN
1759        } else if bits == 0 {
1760            0
1761        } else {
1762            -(1i64 << (bits - 1))
1763        }
1764    }
1765}
1766
1767// ============================================================================
1768// Helper Functions
1769// ============================================================================
1770
/// Folds a byte slice into a 64-bit hash.
///
/// Starting from zero, each byte rotates the accumulator left by 7 bits and
/// is then XOR-ed into it. The result depends on byte order, and an empty
/// slice hashes to `0`. This is a cheap mixing function, not a cryptographic
/// hash.
fn robust_hash_bytes(data: &[u8]) -> u64 {
    data.iter()
        .fold(0u64, |acc, &byte| acc.rotate_left(7) ^ u64::from(byte))
}
1778
1779fn apply_mutator(
1780    mutator: &FuzzMutator,
1781    input: &mut Vec<u8>,
1782    fuzz: &NyxFuzzConfig,
1783    rng: &mut RandomGenerator,
1784) {
1785    match mutator {
1786        FuzzMutator::FlipBit => {
1787            if input.is_empty() {
1788                input.push(0);
1789            }
1790            let idx = rng.gen_range(input.len());
1791            let bit = rng.gen_range(8);
1792            input[idx] ^= 1u8 << bit;
1793        }
1794        FuzzMutator::FlipByte => {
1795            if input.is_empty() {
1796                input.push(0);
1797            }
1798            let idx = rng.gen_range(input.len());
1799            input[idx] ^= rng.next_u64() as u8;
1800        }
1801        FuzzMutator::InsertByte => {
1802            let idx = if input.is_empty() {
1803                0
1804            } else {
1805                rng.gen_range(input.len() + 1)
1806            };
1807            let byte = if !fuzz.dictionary.is_empty() {
1808                let d = rng.gen_range(fuzz.dictionary.len());
1809                let entry = &fuzz.dictionary[d];
1810                if entry.is_empty() {
1811                    0
1812                } else {
1813                    entry[rng.gen_range(entry.len())]
1814                }
1815            } else {
1816                rng.next_u64() as u8
1817            };
1818            input.insert(idx, byte);
1819        }
1820        FuzzMutator::DeleteByte => {
1821            if input.len() > 1 {
1822                let idx = rng.gen_range(input.len());
1823                input.remove(idx);
1824            }
1825        }
1826        FuzzMutator::SpliceSeed => {
1827            if fuzz.seeds.is_empty() {
1828                return;
1829            }
1830            let seed = &fuzz.seeds[rng.gen_range(fuzz.seeds.len())];
1831            if input.is_empty() {
1832                input.extend_from_slice(seed);
1833            } else if !seed.is_empty() {
1834                let cut = rng.gen_range(input.len());
1835                let seed_cut = rng.gen_range(seed.len());
1836                let mut out = Vec::new();
1837                out.extend_from_slice(&input[..cut]);
1838                out.extend_from_slice(&seed[seed_cut..]);
1839                *input = out;
1840            }
1841        }
1842        FuzzMutator::ResetSeed => {
1843            if fuzz.seeds.is_empty() {
1844                return;
1845            }
1846            *input = fuzz.seeds[rng.gen_range(fuzz.seeds.len())].clone();
1847        }
1848        FuzzMutator::Havoc => {
1849            let flips = 1 + rng.gen_range(8);
1850            for _ in 0..flips {
1851                if input.is_empty() {
1852                    input.push(0);
1853                }
1854                let idx = rng.gen_range(input.len());
1855                input[idx] ^= rng.next_u64() as u8;
1856            }
1857        }
1858    }
1859}
1860
1861// ============================================================================
1862// Tests
1863// ============================================================================
1864
#[cfg(test)]
mod tests {
    use super::*;

    /// Hex encoding produces lowercase digits and decodes back to the input.
    #[test]
    fn test_hex_encoding() {
        let raw = b"hello";
        let hex_text = hex_encode(raw);
        assert_eq!(hex_text, "68656c6c6f");
        assert_eq!(hex_decode(&hex_text).unwrap(), raw);
    }

    /// Equal inputs hash equally; the two sample inputs hash differently.
    #[test]
    fn test_robust_hash() {
        let first = robust_hash_bytes(b"test data");
        let repeat = robust_hash_bytes(b"test data");
        let other = robust_hash_bytes(b"different");

        assert_eq!(first, repeat);
        assert_ne!(first, other);
    }

    /// Both payload encodings encode and decode the same sample bytes.
    #[test]
    fn test_payload_encoding() {
        let sample = b"test";

        // Encoding direction.
        assert_eq!(PayloadEncoding::Utf8.encode(sample), "test");
        assert_eq!(PayloadEncoding::Hex.encode(sample), "74657374");

        // Decoding direction.
        assert_eq!(PayloadEncoding::Utf8.decode("test").unwrap(), sample);
        assert_eq!(PayloadEncoding::Hex.decode("74657374").unwrap(), sample);
    }
}