1use std::fs::File;
5use std::os::unix::fs::MetadataExt;
6
7use vm_memory::{GuestAddress, GuestMemory, ReadVolatile, VolatileMemoryError};
8
9use crate::arch::initrd_load_addr;
10use crate::utils::u64_to_usize;
11use crate::vmm_config::boot_source::BootConfig;
12use crate::vstate::memory::GuestMemoryMmap;
13
14#[derive(Debug, thiserror::Error, displaydoc::Display)]
16pub enum InitrdError {
17 Address,
19 Load,
21 Metadata(std::io::Error),
23 CloneFd(std::io::Error),
25 Read(VolatileMemoryError),
27}
28
/// Location and size of an initrd image that has been copied into guest memory.
#[derive(Debug)]
pub struct InitrdConfig {
    /// Guest address at which the initrd image was placed.
    pub address: GuestAddress,
    /// Size of the initrd image in bytes (taken from the source file's metadata).
    pub size: usize,
}
37
38impl InitrdConfig {
39 pub fn from_config(
41 boot_cfg: &BootConfig,
42 vm_memory: &GuestMemoryMmap,
43 ) -> Result<Option<Self>, InitrdError> {
44 Ok(match &boot_cfg.initrd_file {
45 Some(f) => {
46 let f = f.try_clone().map_err(InitrdError::CloneFd)?;
47 Some(Self::from_file(vm_memory, f)?)
48 }
49 None => None,
50 })
51 }
52
53 pub fn from_file(vm_memory: &GuestMemoryMmap, mut file: File) -> Result<Self, InitrdError> {
55 let size = file.metadata().map_err(InitrdError::Metadata)?.size();
56 let size = u64_to_usize(size);
57 let Some(address) = initrd_load_addr(vm_memory, size) else {
58 return Err(InitrdError::Address);
59 };
60 let mut slice = vm_memory
61 .get_slice(GuestAddress(address), size)
62 .map_err(|_| InitrdError::Load)?;
63 file.read_exact_volatile(&mut slice)
64 .map_err(InitrdError::Read)?;
65
66 Ok(InitrdConfig {
67 address: GuestAddress(address),
68 size,
69 })
70 }
71}
72
#[cfg(test)]
mod tests {
    use std::io::{Seek, SeekFrom, Write};

    use vmm_sys_util::tempfile::TempFile;

    use super::*;
    use crate::arch::GUEST_PAGE_SIZE;
    use crate::test_utils::{single_region_mem, single_region_mem_at};

    /// Produces a one-million-byte dummy image where every byte is `0xAA`.
    fn make_test_bin() -> Vec<u8> {
        vec![0xAA; 1_000_000]
    }

    /// Writes `contents` into a fresh unlinked temporary file and rewinds it to offset 0.
    fn rewound_file_with(contents: &[u8]) -> File {
        let mut file = TempFile::new().unwrap().into_file();
        file.write_all(contents).unwrap();
        file.seek(SeekFrom::Start(0)).unwrap();
        file
    }

    #[test]
    fn test_load_initrd() {
        let image = make_test_bin();
        let file = rewound_file_with(&image);

        // Room for twice the image plus one page; aarch64 also reserves space for the FDT.
        let mem_size: usize = image.len() * 2 + GUEST_PAGE_SIZE;
        #[cfg(target_arch = "x86_64")]
        let gm = single_region_mem(mem_size);
        #[cfg(target_arch = "aarch64")]
        let gm = single_region_mem(mem_size + crate::arch::aarch64::layout::FDT_MAX_SIZE);

        let initrd = InitrdConfig::from_file(&gm, file).unwrap();
        assert!(gm.address_in_range(initrd.address));
        assert_eq!(initrd.size, image.len());
    }

    #[test]
    fn test_load_initrd_no_memory() {
        // A 79-byte region cannot possibly hold the 1 MB image.
        let gm = single_region_mem(79);
        let file = rewound_file_with(&make_test_bin());

        let res = InitrdConfig::from_file(&gm, file);
        assert!(matches!(res, Err(InitrdError::Address)), "{:?}", res);
    }

    #[test]
    fn test_load_initrd_unaligned() {
        let image = vec![1u8, 2, 3, 4];
        let file = rewound_file_with(&image);
        // Memory starts one byte past a page boundary, so no aligned load address fits.
        let gm = single_region_mem_at(GUEST_PAGE_SIZE as u64 + 1, image.len() * 2);

        let res = InitrdConfig::from_file(&gm, file);
        assert!(matches!(res, Err(InitrdError::Address)), "{:?}", res);
    }
}