use super::data::Enwik8Mmap;
use super::export::export_safetensors;
use super::model::TrainModelConfig;
use super::model_fast::{FastTrainModel, FastTrainState};
use super::validate::validate_roundtrip;
use anyhow::{bail, Context, Result};
use std::path::PathBuf;
use std::time::Instant;
use tch::{no_grad, Device, Tensor};

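/// Copies a tensor (any device, any dtype) into a contiguous CPU `Vec<f32>`
/// so it can be serialized into the optimizer checkpoint.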
fn tensor_to_vec_f32(t: &Tensor) -> Result<Vec<f32>> {
    let t = t
        .to_device(Device::Cpu)
        .to_kind(tch::Kind::Float)
        .contiguous();
    let n = t.numel();
    let mut out = vec![0f32; n as usize];
    t.f_copy_data(&mut out, n)
        .context("Failed to copy tensor data")?;
    Ok(out)
}

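/// Returns a tensor's shape as `Vec<usize>` for the safetensors header.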
fn tensor_shape_usize(t: &Tensor) -> Vec<usize> {
    t.size().iter().map(|&d| d as usize).collect()
}

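/// Hyper-parameters for one enwik8 training run.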
#[derive(Debug, Clone)]
pub struct TrainConfig {
    pub dataset_path: PathBuf,
    pub output_model_path: PathBuf,
    pub resume_model_path: Option<PathBuf>,

    pub steps: usize,
    pub batch_size: i64,
    pub seq_len: i64,
    pub grad_accum_steps: usize,

    pub lr: f64,
    pub weight_decay: f64,
    pub beta1: f64,
    pub beta2: f64,
    pub adam_eps: f64,

    pub seed: u64,
    pub device: Option<String>,

    pub validate_roundtrip_path: Option<PathBuf>,

    pub model_cfg: TrainModelConfig,
}

impl Default for TrainConfig {
    fn default() -> Self {
        Self {
            dataset_path: PathBuf::from("files/enwik8"),
            output_model_path: PathBuf::from("out/rwkv7_byte_small.safetensors"),
            resume_model_path: None,
            steps: 2_000,
            batch_size: 32,
            seq_len: 128,
            grad_accum_steps: 1,
            lr: 2e-4,
            weight_decay: 0.1,
            beta1: 0.9,
            beta2: 0.99,
            adam_eps: 1e-8,
            seed: 42,
            device: None,
            validate_roundtrip_path: Some(PathBuf::from("files/bench.txt")),
            model_cfg: TrainModelConfig::small_default(),
        }
    }
}

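/// Summary statistics returned by `train_enwik8` after training completes.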
#[derive(Debug, Clone)]
pub struct TrainReport {
    pub device: String,
    pub steps: usize,
    pub tokens_per_step: i64,
    pub final_loss: f64,
    pub steps_per_sec: f64,
    pub tokens_per_sec: f64,
    pub output_model_path: PathBuf,
}

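// A minimal hand-rolled decoupled AdamW; each parameter tensor carries its
// own first-moment (m) and second-moment (v) accumulators.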
struct AdamWEntry {
    p: Tensor,
    m: Tensor,
    v: Tensor,
}

struct AdamW {
    entries: Vec<AdamWEntry>,
    beta1: f64,
    beta2: f64,
    eps: f64,
    weight_decay: f64,
    step: i64,
}

impl AdamW {
    fn new(params: Vec<Tensor>, beta1: f64, beta2: f64, eps: f64, weight_decay: f64) -> Self {
        let mut entries = Vec::with_capacity(params.len());
        for p in params {
            let m = Tensor::zeros_like(&p);
            let v = Tensor::zeros_like(&p);
            entries.push(AdamWEntry { p, m, v });
        }
        Self {
            entries,
            beta1,
            beta2,
            eps,
            weight_decay,
            step: 0,
        }
    }

    fn new_with_state(
        params: Vec<Tensor>,
        beta1: f64,
        beta2: f64,
        eps: f64,
        weight_decay: f64,
        checkpoint_path: Option<&std::path::Path>,
    ) -> Result<Self> {
        let mut optimizer = Self::new(params, beta1, beta2, eps, weight_decay);

        if let Some(path) = checkpoint_path {
            if path.exists() {
                optimizer.load_state(path)?;
                println!("Loaded optimizer state from checkpoint");
            } else {
                println!(
                    "Warning: Optimizer checkpoint not found, starting with fresh optimizer state"
                );
            }
        }

        Ok(optimizer)
    }

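    /// Writes optimizer state in safetensors layout: an 8-byte little-endian
    /// header length, a JSON header mapping tensor names to dtype/shape/byte
    /// offsets, then the raw F32 payloads in header order. The step counter
    /// is stored as the single-element tensor `optimizer.step`; entry `i`
    /// contributes `optimizer.{i}.m` and `optimizer.{i}.v`.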
    fn save_state(&self, path: &std::path::Path) -> Result<()> {
        use std::collections::BTreeMap;
        use std::fs;

        let mut tensors: BTreeMap<String, (Vec<usize>, Vec<f32>)> = BTreeMap::new();

        tensors.insert(
            "optimizer.step".to_string(),
            (vec![1], vec![self.step as f32]),
        );

        for (i, entry) in self.entries.iter().enumerate() {
            let prefix = format!("optimizer.{}", i);

            tensors.insert(
                format!("{}.m", prefix),
                (tensor_shape_usize(&entry.m), tensor_to_vec_f32(&entry.m)?),
            );

            tensors.insert(
                format!("{}.v", prefix),
                (tensor_shape_usize(&entry.v), tensor_to_vec_f32(&entry.v)?),
            );
        }

        let mut cursor: usize = 0;
        let mut meta_entries: Vec<String> = Vec::with_capacity(tensors.len());
        let mut data_blobs: Vec<Vec<u8>> = Vec::with_capacity(tensors.len());

        for (name, (shape, data_f32)) in tensors.into_iter() {
            let byte_len = data_f32.len() * 4;
            let start = cursor;
            let end = cursor + byte_len;
            cursor = end;

            let mut bytes = Vec::with_capacity(byte_len);
            for v in data_f32 {
                bytes.extend_from_slice(&v.to_le_bytes());
            }
            data_blobs.push(bytes);

            let shape_json = shape
                .iter()
                .map(|d| d.to_string())
                .collect::<Vec<_>>()
                .join(",");
            let entry = format!(
                "\"{}\":{{\"dtype\":\"F32\",\"shape\":[{}],\"data_offsets\":[{},{}]}}",
                name, shape_json, start, end
            );
            meta_entries.push(entry);
        }

        let header_json = format!("{{\"__metadata__\":{{}},{} }}", meta_entries.join(","));
        let header_bytes = header_json.as_bytes();
        let header_len = header_bytes.len() as u64;

        let mut out = Vec::with_capacity(8 + header_bytes.len() + cursor);
        out.extend_from_slice(&header_len.to_le_bytes());
        out.extend_from_slice(header_bytes);
        for blob in data_blobs {
            out.extend_from_slice(&blob);
        }

        fs::write(path, out).with_context(|| {
            format!("Failed to write optimizer checkpoint to {}", path.display())
        })?;

        Ok(())
    }

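    /// Restores state written by `save_state`. Keys absent from the
    /// checkpoint are skipped, leaving those accumulators at zero.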
    fn load_state(&mut self, path: &std::path::Path) -> Result<()> {
        use std::fs::File;
        use std::io::{Read, Seek};

        let mut file = File::open(path)
            .with_context(|| format!("Failed to open optimizer checkpoint: {}", path.display()))?;

        let mut header_len_bytes = [0u8; 8];
        file.read_exact(&mut header_len_bytes)?;
        let header_len = u64::from_le_bytes(header_len_bytes) as usize;

        let mut header_bytes = vec![0u8; header_len];
        file.read_exact(&mut header_bytes)?;
        let header_str = std::str::from_utf8(&header_bytes)
            .context("Invalid UTF-8 in optimizer checkpoint header")?;

        let json: serde_json::Value = serde_json::from_str(header_str)
            .context("Failed to parse optimizer checkpoint JSON header")?;

        let data_offset = 8 + header_len;

        // Reads one F32 tensor (shape + data) given its JSON header entry.
        let mut read_f32 = |info: &serde_json::Value, what: &str| -> Result<(Vec<i64>, Vec<f32>)> {
            let shape = info["shape"]
                .as_array()
                .ok_or_else(|| anyhow::anyhow!("Invalid shape for {}", what))?
                .iter()
                .map(|v| v.as_u64().unwrap() as i64)
                .collect::<Vec<i64>>();

            let offsets = info["data_offsets"]
                .as_array()
                .ok_or_else(|| anyhow::anyhow!("Invalid data_offsets for {}", what))?
                .iter()
                .map(|v| v.as_u64().unwrap() as usize)
                .collect::<Vec<usize>>();

            let byte_len = offsets[1] - offsets[0];
            let mut raw_bytes = vec![0u8; byte_len];
            file.seek(std::io::SeekFrom::Start((data_offset + offsets[0]) as u64))?;
            file.read_exact(&mut raw_bytes)?;

            let mut data = vec![0f32; byte_len / 4];
            for (j, chunk) in raw_bytes.chunks_exact(4).enumerate() {
                data[j] = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
            }
            Ok((shape, data))
        };

        if let Some(step_info) = json.get("optimizer.step") {
            let (_, data) = read_f32(step_info, "step")?;
            self.step = data[0] as i64;
        }

        for (i, entry) in self.entries.iter_mut().enumerate() {
            let prefix = format!("optimizer.{}", i);

            if let Some(m_info) = json.get(&format!("{}.m", prefix)) {
                let (shape, data) = read_f32(m_info, "momentum")?;
                entry.m = Tensor::f_from_slice(&data)?
                    .view(&*shape)
                    .to_device(entry.m.device());
            }

            if let Some(v_info) = json.get(&format!("{}.v", prefix)) {
                let (shape, data) = read_f32(v_info, "variance")?;
                entry.v = Tensor::f_from_slice(&data)?
                    .view(&*shape)
                    .to_device(entry.v.device());
            }
        }

        Ok(())
    }

    fn zero_grad(&mut self) {
        for e in &mut self.entries {
            e.p.zero_grad();
        }
    }

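    // One AdamW step. Bias corrections are folded into a single step size,
    //   step_size = lr * sqrt(1 - beta2^t) / (1 - beta1^t),
    // so the update m / (sqrt(v) + eps) * step_size equals the textbook
    // m_hat / (sqrt(v_hat) + eps_hat) with eps_hat = eps / sqrt(1 - beta2^t).
    // Weight decay is decoupled: applied to the parameter, not the gradient.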
    fn step(&mut self, lr: f64) {
        self.step += 1;
        let t = self.step as f64;
        let bias_c1 = 1.0 - self.beta1.powf(t);
        let bias_c2 = 1.0 - self.beta2.powf(t);
        let step_size = lr * (bias_c2.sqrt() / bias_c1);

        no_grad(|| {
            for e in &mut self.entries {
                let g = e.p.grad();
                if !g.defined() {
                    continue;
                }

                e.m = &e.m * self.beta1 + &g * (1.0 - self.beta1);
                e.v = &e.v * self.beta2 + g.square() * (1.0 - self.beta2);

                let denom = e.v.sqrt() + self.eps;

                if self.weight_decay != 0.0 {
                    let wd = &e.p * (-lr * self.weight_decay);
                    let _ = e.p.f_add_(&wd);
                }

                let update = (&e.m / denom) * (-step_size);
                let _ = e.p.f_add_(&update);
            }
        });

        self.zero_grad();
    }
}

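// Resolves the training device. An explicit request is honored (failing fast
// if CUDA is asked for but unavailable); with no request, CUDA is required
// and CPU use must be opted into explicitly.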
fn pick_device(requested: &Option<String>) -> Result<Device> {
    let cuda_available = tch::Cuda::is_available();

    if let Some(s) = requested.as_deref() {
        match s {
            "cpu" => return Ok(Device::Cpu),
            "cuda" | "gpu" => {
                if cuda_available {
                    return Ok(Device::Cuda(0));
                }
                bail!("CUDA was requested but is not available. Set TORCH_CUDA_VERSION=121 (or 118) and rebuild, or pass --device cpu to force CPU.");
            }
            other => bail!("Unknown device '{}'. Use cpu or cuda", other),
        }
    }

    if cuda_available {
        Ok(Device::Cuda(0))
    } else {
        bail!(
            "CUDA not available. Training expects a CUDA-enabled libtorch. Install/rebuild with TORCH_CUDA_VERSION=121 (or 118) or rerun with --device cpu explicitly."
        )
    }
}

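/// Trains the byte-level model on enwik8 and exports the weights as
/// safetensors, optionally resuming model and optimizer state.
///
/// A minimal sketch of an invocation, assuming `files/enwik8` and the
/// default output directory exist:
///
/// ```ignore
/// let report = train_enwik8(TrainConfig::default())?;
/// eprintln!("final loss {:.4} at {:.0} tok/s", report.final_loss, report.tokens_per_sec);
/// ```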
pub fn train_enwik8(cfg: TrainConfig) -> Result<TrainReport> {
    cfg.model_cfg.validate()?;

    let device = pick_device(&cfg.device)?;
    let device_str = match device {
        Device::Cpu => "cpu".to_string(),
        Device::Cuda(_) => {
            tch::Cuda::cudnn_set_benchmark(true);
            "cuda:0".to_string()
        }
        _ => format!("{:?}", device),
    };

    if let Some(parent) = cfg.output_model_path.parent() {
        std::fs::create_dir_all(parent)
            .with_context(|| format!("create output dir {}", parent.display()))?;
    }

    let ds = Enwik8Mmap::open(&cfg.dataset_path)?;
    let mut rng = Enwik8Mmap::seeded_rng(cfg.seed);

    let model = if let Some(resume_path) = &cfg.resume_model_path {
        println!("Loading model weights from {}...", resume_path.display());
        FastTrainModel::load_from_safetensors(resume_path, cfg.model_cfg.clone(), device)
            .with_context(|| format!("Failed to load model from {}", resume_path.display()))?
    } else {
        FastTrainModel::new(cfg.model_cfg.clone(), device, cfg.seed as i64)?
    };

    let mut state = FastTrainState::new(&cfg.model_cfg, cfg.batch_size, device)?;

    let params = model.parameters();

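    // Optimizer state lives next to the model weights as
    // `<stem>.opt.safetensors`; on resume it is looked up next to the
    // resume model.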
    let optimizer_checkpoint_path = cfg.output_model_path.with_extension("opt.safetensors");

    let resume_optimizer_path = cfg
        .resume_model_path
        .as_ref()
        .map(|p| p.with_extension("opt.safetensors"));

    let mut opt = AdamW::new_with_state(
        params,
        cfg.beta1,
        cfg.beta2,
        cfg.adam_eps,
        cfg.weight_decay,
        resume_optimizer_path.as_deref(),
    )?;

    let accum_steps = cfg.grad_accum_steps.max(1);
    let tokens_per_step = cfg.batch_size * cfg.seq_len * (accum_steps as i64);
    let start = Instant::now();
    let mut last_loss = 0.0f64;

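    // Each optimizer step accumulates gradients over `accum_steps`
    // micro-batches. The loss is pre-divided by `accum_steps` so the summed
    // gradients form a mean, and multiplied back when logging so the
    // reported loss stays comparable across accumulation settings.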
    for step in 0..cfg.steps {
        let mut step_loss_sum = 0.0f64;

        for micro_step in 0..accum_steps {
            state.reset();

            let (x, y) = ds.sample_batch(&mut rng, cfg.batch_size, cfg.seq_len, device)?;

            let logits = model.forward_sequence(&x, &mut state)?;

            let targets = y;
            let loss = logits
                .view([-1, cfg.model_cfg.vocab_size])
                .cross_entropy_for_logits(&targets.view([-1]))
                / (accum_steps as f64);

            loss.backward();

            step_loss_sum += loss.to_device(Device::Cpu).double_value(&[]) * (accum_steps as f64);

            if micro_step == accum_steps - 1 {
                opt.step(cfg.lr);
            }
        }

        last_loss = step_loss_sum / (accum_steps as f64);

        if step % 50 == 0 {
            let elapsed = start.elapsed().as_secs_f64().max(1e-9);
            // `step + 1` full optimizer steps have completed at this point.
            let sps = ((step + 1) as f64) / elapsed;
            let tps = sps * (tokens_per_step as f64);
            eprintln!(
                "step {:6} | loss {:8.4} | {:8.2} steps/s | {:10.0} tok/s | {}",
                step, last_loss, sps, tps, device_str
            );

            if let Err(e) = opt.save_state(&optimizer_checkpoint_path) {
                eprintln!("Warning: Failed to save optimizer checkpoint: {}", e);
            }
        }
    }

    export_safetensors(&cfg.output_model_path, &cfg.model_cfg, &model.p)
        .context("export safetensors")?;

    if let Err(e) = opt.save_state(&optimizer_checkpoint_path) {
        eprintln!("Warning: Failed to save final optimizer checkpoint: {}", e);
    }

    if let Some(p) = &cfg.validate_roundtrip_path {
        validate_roundtrip(&cfg.output_model_path, p)
            .with_context(|| format!("validate roundtrip using {}", p.display()))?;
    }

    let elapsed = start.elapsed().as_secs_f64().max(1e-9);
    let steps_per_sec = (cfg.steps as f64) / elapsed;
    let tokens_per_sec = steps_per_sec * (tokens_per_step as f64);

    Ok(TrainReport {
        device: device_str,
        steps: cfg.steps,
        tokens_per_step,
        final_loss: last_loss,
        steps_per_sec,
        tokens_per_sec,
        output_model_path: cfg.output_model_path,
    })
}