rwkvzip/rwkv7/training/train.rs

use super::data::Enwik8Mmap;
use super::export::export_safetensors;
use super::model::TrainModelConfig;
use super::model_fast::{FastTrainModel, FastTrainState};
use super::validate::validate_roundtrip;
use anyhow::{bail, Context, Result};
use std::path::PathBuf;
use std::time::Instant;
use tch::{no_grad, Device, Tensor};

// Helper functions for optimizer state saving/loading
fn tensor_to_vec_f32(t: &Tensor) -> Result<Vec<f32>> {
    let t = t
        .to_device(Device::Cpu)
        .to_kind(tch::Kind::Float)
        .contiguous();
    let n = t.numel();
    let mut out = vec![0f32; n as usize];
    t.f_copy_data(&mut out, n)
        .context("Failed to copy tensor data")?;
    Ok(out)
}

fn tensor_shape_usize(t: &Tensor) -> Vec<usize> {
    t.size().iter().map(|&d| d as usize).collect()
}

#[derive(Debug, Clone)]
pub struct TrainConfig {
    pub dataset_path: PathBuf,
    pub output_model_path: PathBuf,
    pub resume_model_path: Option<PathBuf>,

    pub steps: usize,
    pub batch_size: i64,
    pub seq_len: i64,
    pub grad_accum_steps: usize,

    pub lr: f64,
    pub weight_decay: f64,
    pub beta1: f64,
    pub beta2: f64,
    pub adam_eps: f64,

    pub seed: u64,
    pub device: Option<String>,

    pub validate_roundtrip_path: Option<PathBuf>,

    pub model_cfg: TrainModelConfig,
}

impl Default for TrainConfig {
    fn default() -> Self {
        Self {
            dataset_path: PathBuf::from("files/enwik8"),
            output_model_path: PathBuf::from("out/rwkv7_byte_small.safetensors"),
            resume_model_path: None,
            steps: 2_000,
            batch_size: 32,
            seq_len: 128,
            grad_accum_steps: 1,
            lr: 2e-4,
            weight_decay: 0.1,
            beta1: 0.9,
            beta2: 0.99,
            adam_eps: 1e-8,
            seed: 42,
            device: None,
            validate_roundtrip_path: Some(PathBuf::from("files/bench.txt")),
            model_cfg: TrainModelConfig::small_default(),
        }
    }
}
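
// Illustrative usage (a minimal sketch, not wired into any CLI; the function name and the
// field values are placeholders, not recommended hyperparameters): build a config from the
// defaults above, override a few fields, and run a short CPU-only training job.
#[allow(dead_code)]
fn example_short_cpu_run() -> Result<TrainReport> {
    let cfg = TrainConfig {
        steps: 100,
        batch_size: 8,
        seq_len: 64,
        device: Some("cpu".to_string()),
        validate_roundtrip_path: None,
        ..Default::default()
    };
    train_enwik8(cfg)
}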

#[derive(Debug, Clone)]
pub struct TrainReport {
    pub device: String,
    pub steps: usize,
    pub tokens_per_step: i64,
    pub final_loss: f64,
    pub steps_per_sec: f64,
    pub tokens_per_sec: f64,
    pub output_model_path: PathBuf,
}

struct AdamWEntry {
    p: Tensor,
    m: Tensor,
    v: Tensor,
}

struct AdamW {
    entries: Vec<AdamWEntry>,
    beta1: f64,
    beta2: f64,
    eps: f64,
    weight_decay: f64,
    step: i64,
}

impl AdamW {
    fn new(params: Vec<Tensor>, beta1: f64, beta2: f64, eps: f64, weight_decay: f64) -> Self {
        let mut entries = Vec::with_capacity(params.len());
        for p in params {
            let m = Tensor::zeros_like(&p);
            let v = Tensor::zeros_like(&p);
            entries.push(AdamWEntry { p, m, v });
        }
        Self {
            entries,
            beta1,
            beta2,
            eps,
            weight_decay,
            step: 0,
        }
    }

    /// Create a new AdamW optimizer and optionally load state from a checkpoint
    fn new_with_state(
        params: Vec<Tensor>,
        beta1: f64,
        beta2: f64,
        eps: f64,
        weight_decay: f64,
        checkpoint_path: Option<&std::path::Path>,
    ) -> Result<Self> {
        let mut optimizer = Self::new(params, beta1, beta2, eps, weight_decay);

        // Try to load optimizer state if checkpoint is provided
        if let Some(path) = checkpoint_path {
            if path.exists() {
                optimizer.load_state(path)?;
                println!("Loaded optimizer state from checkpoint");
            } else {
                println!(
                    "Warning: Optimizer checkpoint not found, starting with fresh optimizer state"
                );
            }
        }

        Ok(optimizer)
    }

    /// Save optimizer state to a checkpoint file
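    ///
    /// Sketch of the resulting layout (the code below mirrors the safetensors framing, as
    /// noted in the "Write in the same format as safetensors" step): an 8-byte little-endian
    /// header length, then a JSON header mapping each tensor name to its dtype, shape, and
    /// data_offsets, then the raw F32 tensor bytes concatenated in offset order.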
    fn save_state(&self, path: &std::path::Path) -> Result<()> {
        use std::collections::BTreeMap;
        use std::fs;

        let mut tensors: BTreeMap<String, (Vec<usize>, Vec<f32>)> = BTreeMap::new();

        // Save step count
        tensors.insert(
            "optimizer.step".to_string(),
            (vec![1], vec![self.step as f32]),
        );

        // Save momentum and variance for each parameter
        for (i, entry) in self.entries.iter().enumerate() {
            let prefix = format!("optimizer.{}", i);

            // Save momentum (m)
            tensors.insert(
                format!("{}.m", prefix),
                (tensor_shape_usize(&entry.m), tensor_to_vec_f32(&entry.m)?),
            );

            // Save variance (v)
            tensors.insert(
                format!("{}.v", prefix),
                (tensor_shape_usize(&entry.v), tensor_to_vec_f32(&entry.v)?),
            );
        }

        // Write in the same format as safetensors
        let mut cursor: usize = 0;
        let mut meta_entries: Vec<String> = Vec::with_capacity(tensors.len());
        let mut data_blobs: Vec<Vec<u8>> = Vec::with_capacity(tensors.len());

        for (name, (shape, data_f32)) in tensors.into_iter() {
            let byte_len = data_f32.len() * 4;
            let start = cursor;
            let end = cursor + byte_len;
            cursor = end;

            let mut bytes = Vec::with_capacity(byte_len);
            for v in data_f32 {
                bytes.extend_from_slice(&v.to_le_bytes());
            }
            data_blobs.push(bytes);

            let shape_json = shape
                .iter()
                .map(|d| d.to_string())
                .collect::<Vec<_>>()
                .join(",");
            let entry = format!(
                "\"{}\":{{\"dtype\":\"F32\",\"shape\":[{}],\"data_offsets\":[{},{}]}}",
                name, shape_json, start, end
            );
            meta_entries.push(entry);
        }

        let header_json = format!("{{\"__metadata__\":{{}},{} }}", meta_entries.join(","));
        let header_bytes = header_json.as_bytes();
        let header_len = header_bytes.len() as u64;

        let mut out = Vec::with_capacity(8 + header_bytes.len() + cursor);
        out.extend_from_slice(&header_len.to_le_bytes());
        out.extend_from_slice(header_bytes);
        for blob in data_blobs {
            out.extend_from_slice(&blob);
        }

        fs::write(path, out).with_context(|| {
            format!("Failed to write optimizer checkpoint to {}", path.display())
        })?;

        Ok(())
    }

    /// Load optimizer state from a checkpoint file
    fn load_state(&mut self, path: &std::path::Path) -> Result<()> {
        use std::fs::File;
        use std::io::{Read, Seek};

        let mut file = File::open(path)
            .with_context(|| format!("Failed to open optimizer checkpoint: {}", path.display()))?;

        // Read header length
        let mut header_len_bytes = [0u8; 8];
        file.read_exact(&mut header_len_bytes)?;
        let header_len = u64::from_le_bytes(header_len_bytes) as usize;

        // Read JSON header
        let mut header_bytes = vec![0u8; header_len];
        file.read_exact(&mut header_bytes)?;
        let header_str = std::str::from_utf8(&header_bytes)
            .context("Invalid UTF-8 in optimizer checkpoint header")?;

        // Parse JSON header
        let json: serde_json::Value = serde_json::from_str(header_str)
            .context("Failed to parse optimizer checkpoint JSON header")?;

        let data_offset = 8 + header_len;

        // Load step count
        if let Some(step_info) = json.get("optimizer.step") {
            let offsets = step_info["data_offsets"]
                .as_array()
                .ok_or_else(|| anyhow::anyhow!("Invalid data_offsets for step"))?
                .iter()
                .map(|v| v.as_u64().unwrap() as usize)
                .collect::<Vec<usize>>();

            let byte_len = offsets[1] - offsets[0];
            let mut raw_bytes = vec![0u8; byte_len];
            file.seek(std::io::SeekFrom::Start((data_offset + offsets[0]) as u64))?;
            file.read_exact(&mut raw_bytes)?;

            // Convert bytes to f32 and then to i64
            let step_f32 =
                f32::from_le_bytes([raw_bytes[0], raw_bytes[1], raw_bytes[2], raw_bytes[3]]);
            self.step = step_f32 as i64;
        }

        // Load momentum and variance for each parameter
        for (i, entry) in self.entries.iter_mut().enumerate() {
            let prefix = format!("optimizer.{}", i);

            // Load momentum (m)
            if let Some(m_info) = json.get(&format!("{}.m", prefix)) {
                let shape = m_info["shape"]
                    .as_array()
                    .ok_or_else(|| anyhow::anyhow!("Invalid shape for momentum"))?
                    .iter()
                    .map(|v| v.as_u64().unwrap() as i64)
                    .collect::<Vec<i64>>();

                let offsets = m_info["data_offsets"]
                    .as_array()
                    .ok_or_else(|| anyhow::anyhow!("Invalid data_offsets for momentum"))?
                    .iter()
                    .map(|v| v.as_u64().unwrap() as usize)
                    .collect::<Vec<usize>>();

                let byte_len = offsets[1] - offsets[0];
                let mut raw_bytes = vec![0u8; byte_len];
                file.seek(std::io::SeekFrom::Start((data_offset + offsets[0]) as u64))?;
                file.read_exact(&mut raw_bytes)?;

                let mut data = vec![0f32; byte_len / 4];
                for j in 0..data.len() {
                    let base = j * 4;
                    data[j] = f32::from_le_bytes([
                        raw_bytes[base],
                        raw_bytes[base + 1],
                        raw_bytes[base + 2],
                        raw_bytes[base + 3],
                    ]);
                }

                entry.m = Tensor::f_from_slice(&data)?
                    .view(&*shape)
                    .to_device(entry.m.device());
            }

            // Load variance (v)
            if let Some(v_info) = json.get(&format!("{}.v", prefix)) {
                let shape = v_info["shape"]
                    .as_array()
                    .ok_or_else(|| anyhow::anyhow!("Invalid shape for variance"))?
                    .iter()
                    .map(|v| v.as_u64().unwrap() as i64)
                    .collect::<Vec<i64>>();

                let offsets = v_info["data_offsets"]
                    .as_array()
                    .ok_or_else(|| anyhow::anyhow!("Invalid data_offsets for variance"))?
                    .iter()
                    .map(|v| v.as_u64().unwrap() as usize)
                    .collect::<Vec<usize>>();

                let byte_len = offsets[1] - offsets[0];
                let mut raw_bytes = vec![0u8; byte_len];
                file.seek(std::io::SeekFrom::Start((data_offset + offsets[0]) as u64))?;
                file.read_exact(&mut raw_bytes)?;

                let mut data = vec![0f32; byte_len / 4];
                for j in 0..data.len() {
                    let base = j * 4;
                    data[j] = f32::from_le_bytes([
                        raw_bytes[base],
                        raw_bytes[base + 1],
                        raw_bytes[base + 2],
                        raw_bytes[base + 3],
                    ]);
                }

                entry.v = Tensor::f_from_slice(&data)?
                    .view(&*shape)
                    .to_device(entry.v.device());
            }
        }

        Ok(())
    }

    fn zero_grad(&mut self) {
        for e in &mut self.entries {
            e.p.zero_grad();
        }
    }

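    /// One AdamW update with decoupled weight decay. In terms of the fields above and the
    /// gradient `g`, the code below applies:
    ///
    ///   m <- beta1 * m + (1 - beta1) * g
    ///   v <- beta2 * v + (1 - beta2) * g^2
    ///   p <- p - lr * weight_decay * p
    ///   p <- p - lr * sqrt(1 - beta2^t) / (1 - beta1^t) * m / (sqrt(v) + eps)
    ///
    /// The bias correction is folded into the step size rather than into `m` and `v`.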
    fn step(&mut self, lr: f64) {
        self.step += 1;
        let t = self.step as f64;
        let bias_c1 = 1.0 - self.beta1.powf(t);
        let bias_c2 = 1.0 - self.beta2.powf(t);
        let step_size = lr * (bias_c2.sqrt() / bias_c1);

        no_grad(|| {
            for e in &mut self.entries {
                let g = e.p.grad();
                if !g.defined() {
                    continue;
                }

                // m = b1*m + (1-b1)*g
                e.m = &e.m * self.beta1 + &g * (1.0 - self.beta1);
                // v = b2*v + (1-b2)*g^2
                e.v = &e.v * self.beta2 + g.square() * (1.0 - self.beta2);

                // denom = sqrt(v) + eps
                let denom = e.v.sqrt() + self.eps;

                // Decoupled weight decay
                if self.weight_decay != 0.0 {
                    let wd = &e.p * (-lr * self.weight_decay);
                    // p -= lr * weight_decay * p
                    let _ = e.p.f_add_(&wd);
                }

                // p -= step_size * m / denom
                let update = (&e.m / denom) * (-step_size);
                let _ = e.p.f_add_(&update);
            }
        });

        self.zero_grad();
    }
}

fn pick_device(requested: &Option<String>) -> Result<Device> {
    let cuda_available = tch::Cuda::is_available();

    if let Some(s) = requested.as_deref() {
        match s {
            "cpu" => return Ok(Device::Cpu),
            "cuda" | "gpu" => {
                if cuda_available {
                    return Ok(Device::Cuda(0));
                }
                bail!("CUDA was requested but is not available. Set TORCH_CUDA_VERSION=121 (or 118) and rebuild, or pass --device cpu to force CPU.");
            }
            other => bail!("Unknown device '{}'. Use cpu or cuda", other),
        }
    }

    if cuda_available {
        Ok(Device::Cuda(0))
    } else {
        bail!(
            "CUDA not available. Training expects a CUDA-enabled libtorch. Install/rebuild with TORCH_CUDA_VERSION=121 (or 118) or rerun with --device cpu explicitly."
        )
    }
}

pub fn train_enwik8(cfg: TrainConfig) -> Result<TrainReport> {
    cfg.model_cfg.validate()?;

    let device = pick_device(&cfg.device)?;
    let device_str = match device {
        Device::Cpu => "cpu".to_string(),
        Device::Cuda(_) => {
            // Enable cuDNN benchmark mode for better performance
            tch::Cuda::cudnn_set_benchmark(true);
            "cuda:0".to_string()
        }
        _ => format!("{:?}", device),
    };

    // Ensure output dir exists
    if let Some(parent) = cfg.output_model_path.parent() {
        std::fs::create_dir_all(parent)
            .with_context(|| format!("create output dir {}", parent.display()))?;
    }

    let ds = Enwik8Mmap::open(&cfg.dataset_path)?;
    let mut rng = Enwik8Mmap::seeded_rng(cfg.seed);

    // Create model - either from scratch or by loading existing weights
    let model = if let Some(resume_path) = &cfg.resume_model_path {
        println!("Loading model weights from {}...", resume_path.display());
        FastTrainModel::load_from_safetensors(resume_path, cfg.model_cfg.clone(), device)
            .with_context(|| format!("Failed to load model from {}", resume_path.display()))?
    } else {
        FastTrainModel::new(cfg.model_cfg.clone(), device, cfg.seed as i64)?
    };

    let mut state = FastTrainState::new(&cfg.model_cfg, cfg.batch_size, device)?;

    let params = model.parameters();

    // Determine optimizer checkpoint path
    let optimizer_checkpoint_path = cfg.output_model_path.with_extension("opt.safetensors");
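    // e.g. with the default output path, "out/rwkv7_byte_small.safetensors" maps to
    // "out/rwkv7_byte_small.opt.safetensors" (with_extension replaces the final extension).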

    // Create optimizer checkpoint path for loading (if resuming)
    let resume_optimizer_path = cfg
        .resume_model_path
        .as_ref()
        .map(|p| p.with_extension("opt.safetensors"));

    let mut opt = AdamW::new_with_state(
        params,
        cfg.beta1,
        cfg.beta2,
        cfg.adam_eps,
        cfg.weight_decay,
        resume_optimizer_path.as_deref(),
    )?;

    let accum_steps = cfg.grad_accum_steps.max(1);
    let tokens_per_step = cfg.batch_size * cfg.seq_len * (accum_steps as i64);
    let start = Instant::now();
    let mut last_loss = 0.0f64;

    for step in 0..cfg.steps {
        let mut step_loss_sum = 0.0f64;

        for micro_step in 0..accum_steps {
            state.reset();

            let (x, y) = ds.sample_batch(&mut rng, cfg.batch_size, cfg.seq_len, device)?;

            // Forward entire sequence at once (sequence-parallel)
            let logits = model.forward_sequence(&x, &mut state)?; // (B, T, V)
            let targets = y; // (B,T)

            // Cross-entropy over all positions, normalized by accum_steps for correct gradient scale
            let loss = logits
                .view([-1, cfg.model_cfg.vocab_size])
                .cross_entropy_for_logits(&targets.view([-1]))
                / (accum_steps as f64);

            loss.backward();

            // Accumulate loss for logging
            step_loss_sum += loss.to_device(Device::Cpu).double_value(&[]) * (accum_steps as f64);

            // Only step optimizer on last micro-step
            if micro_step == accum_steps - 1 {
                opt.step(cfg.lr);
            }
        }

        last_loss = step_loss_sum / (accum_steps as f64);

        if step % 50 == 0 {
            let elapsed = start.elapsed().as_secs_f64().max(1e-9);
            let sps = (step.max(1) as f64) / elapsed;
            let tps = sps * (tokens_per_step as f64);
            eprintln!(
                "step {:6} | loss {:8.4} | {:8.2} steps/s | {:10.0} tok/s | {}",
                step, last_loss, sps, tps, device_str
            );

            // Save optimizer checkpoint every 50 steps
            if let Err(e) = opt.save_state(&optimizer_checkpoint_path) {
                eprintln!("Warning: Failed to save optimizer checkpoint: {}", e);
            }
        }
    }

    // Export weights
    export_safetensors(&cfg.output_model_path, &cfg.model_cfg, &model.p)
        .context("export safetensors")?;

    // Save final optimizer state
    if let Err(e) = opt.save_state(&optimizer_checkpoint_path) {
        eprintln!("Warning: Failed to save final optimizer checkpoint: {}", e);
    }

    // Optional end-to-end compression validation.
    if let Some(p) = &cfg.validate_roundtrip_path {
        validate_roundtrip(&cfg.output_model_path, p)
            .with_context(|| format!("validate roundtrip using {}", p.display()))?;
    }

    let elapsed = start.elapsed().as_secs_f64().max(1e-9);
    let steps_per_sec = (cfg.steps as f64) / elapsed;
    let tokens_per_sec = steps_per_sec * (tokens_per_step as f64);

    Ok(TrainReport {
        device: device_str,
        steps: cfg.steps,
        tokens_per_step,
        final_loss: last_loss,
        steps_per_sec,
        tokens_per_sec,
        output_model_path: cfg.output_model_path,
    })
}
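
// Illustrative resume sketch (hypothetical helper, not part of any CLI): pointing
// `resume_model_path` at a previously exported model reloads its weights, and
// `train_enwik8` then looks for the sibling "<stem>.opt.safetensors" optimizer
// checkpoint written by the run above. The path and step count here are placeholders.
#[allow(dead_code)]
fn example_resume_run() -> Result<TrainReport> {
    let cfg = TrainConfig {
        resume_model_path: Some(PathBuf::from("out/rwkv7_byte_small.safetensors")),
        steps: 500,
        ..Default::default()
    };
    train_enwik8(cfg)
}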