vmm/devices/virtio/pmem/device.rs

// Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

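//! Implementation of a `virtio-pmem` device. It exposes a host file to the guest as a
//! byte-addressable persistent memory region mapped directly into guest physical memory,
//! and services guest flush requests by `msync`ing the backing file.
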
use std::fs::{File, OpenOptions};
use std::ops::{Deref, DerefMut};
use std::os::fd::AsRawFd;
use std::sync::Arc;

use kvm_bindings::{KVM_MEM_READONLY, kvm_userspace_memory_region};
use serde::{Deserialize, Serialize};
use vm_allocator::AllocPolicy;
use vm_memory::{GuestAddress, GuestMemoryError};
use vmm_sys_util::eventfd::EventFd;

use crate::devices::virtio::ActivateError;
use crate::devices::virtio::device::{ActiveState, DeviceState, VirtioDevice};
use crate::devices::virtio::generated::virtio_config::VIRTIO_F_VERSION_1;
use crate::devices::virtio::generated::virtio_ids::VIRTIO_ID_PMEM;
use crate::devices::virtio::pmem::PMEM_QUEUE_SIZE;
use crate::devices::virtio::pmem::metrics::{PmemMetrics, PmemMetricsPerDevice};
use crate::devices::virtio::queue::{DescriptorChain, InvalidAvailIdx, Queue, QueueError};
use crate::devices::virtio::transport::{VirtioInterrupt, VirtioInterruptType};
use crate::logger::{IncMetric, error, info};
use crate::utils::{align_up, u64_to_usize};
use crate::vmm_config::pmem::PmemConfig;
use crate::vstate::memory::{ByteValued, Bytes, GuestMemoryMmap};
use crate::vstate::vm::VmError;
use crate::{Vm, impl_device_type};

#[derive(Debug, thiserror::Error, displaydoc::Display)]
pub enum PmemError {
    /// Cannot set the memory region: {0}
    SetUserMemoryRegion(VmError),
    /// Unable to allocate a KVM slot for the device
    NoKvmSlotAvailable,
    /// Error accessing backing file: {0}
    BackingFile(std::io::Error),
    /// Backing file size is 0
    BackingFileZeroSize,
    /// Error with EventFd: {0}
    EventFd(std::io::Error),
    /// Unexpected read-only descriptor
    ReadOnlyDescriptor,
    /// Unexpected write-only descriptor
    WriteOnlyDescriptor,
    /// Unknown request type: {0}
    UnknownRequestType(u32),
    /// Descriptor chain too short
    DescriptorChainTooShort,
    /// Guest memory error: {0}
    GuestMemory(#[from] GuestMemoryError),
    /// Error handling the VirtIO queue: {0}
    Queue(#[from] QueueError),
    /// Error obtaining a descriptor from the queue: {0}
    QueuePop(#[from] InvalidAvailIdx),
}

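// `VIRTIO_PMEM_REQ_TYPE_FLUSH` is the only request type this device handles. `SUCCESS` and
// `FAILURE` are the status values written back to the guest in the request's status descriptor.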
const VIRTIO_PMEM_REQ_TYPE_FLUSH: u32 = 0;
const SUCCESS: i32 = 0;
const FAILURE: i32 = -1;

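/// Device config space exposed to the guest: the guest-physical placement of the persistent
/// memory region.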
#[derive(Debug, Default, Copy, Clone, Serialize, Deserialize)]
#[repr(C)]
pub struct ConfigSpace {
    // Physical address of the first byte of the persistent memory region.
    pub start: u64,
    // Length of the address range.
    pub size: u64,
}

// SAFETY: `ConfigSpace` contains only PODs in `repr(C)`, without padding.
unsafe impl ByteValued for ConfigSpace {}

#[derive(Debug)]
pub struct Pmem {
    // VirtIO fields
    pub avail_features: u64,
    pub acked_features: u64,
    pub activate_event: EventFd,

    // Transport fields
    pub device_state: DeviceState,
    pub queues: Vec<Queue>,
    pub queue_events: Vec<EventFd>,

    // Pmem specific fields
    pub config_space: ConfigSpace,
    pub file: File,
    pub file_len: u64,
    pub mmap_ptr: u64,
    pub metrics: Arc<PmemMetrics>,

    pub config: PmemConfig,
}

impl Drop for Pmem {
    fn drop(&mut self) {
        let mmap_len = align_up(self.file_len, Self::ALIGNMENT);
        // SAFETY: `mmap_ptr` is a valid pointer since `Pmem` can only be created with the `new*`
        //         methods. The mapping size calculation is the same as in the original mmap call.
        unsafe {
            _ = libc::munmap(self.mmap_ptr as *mut libc::c_void, u64_to_usize(mmap_len));
        }
    }
}

impl Pmem {
    // Pmem device address and size need to be multiples of 2MiB.
    pub const ALIGNMENT: u64 = 2 * 1024 * 1024;

    /// Create a new Pmem device backed by the file at `config.path_on_host`.
    pub fn new(config: PmemConfig) -> Result<Self, PmemError> {
        Self::new_with_queues(config, vec![Queue::new(PMEM_QUEUE_SIZE)])
    }

    /// Create a new Pmem device backed by the file at `config.path_on_host`, using a pre-created
    /// set of queues.
    pub fn new_with_queues(config: PmemConfig, queues: Vec<Queue>) -> Result<Self, PmemError> {
        let (file, file_len, mmap_ptr, mmap_len) =
            Self::mmap_backing_file(&config.path_on_host, config.read_only)?;

        Ok(Self {
            avail_features: 1u64 << VIRTIO_F_VERSION_1,
            acked_features: 0u64,
            activate_event: EventFd::new(libc::EFD_NONBLOCK).map_err(PmemError::EventFd)?,
            device_state: DeviceState::Inactive,
            queues,
            queue_events: vec![EventFd::new(libc::EFD_NONBLOCK).map_err(PmemError::EventFd)?],
            config_space: ConfigSpace {
                start: 0,
                size: mmap_len,
            },
            file,
            file_len,
            mmap_ptr,
            metrics: PmemMetricsPerDevice::alloc(config.id.clone()),
            config,
        })
    }

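    /// Open and memory-map the backing file. Returns the opened `File`, the file length, the
    /// address of the mapping and the mapping length (the file length rounded up to
    /// [`Self::ALIGNMENT`]).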
    fn mmap_backing_file(path: &str, read_only: bool) -> Result<(File, u64, u64, u64), PmemError> {
        let file = OpenOptions::new()
            .read(true)
            .write(!read_only)
            .open(path)
            .map_err(PmemError::BackingFile)?;
        let file_len = file.metadata().map_err(PmemError::BackingFile)?.len();
        if file_len == 0 {
            return Err(PmemError::BackingFileZeroSize);
        }

        let mut prot = libc::PROT_READ;
        if !read_only {
            prot |= libc::PROT_WRITE;
        }

        let mmap_len = align_up(file_len, Self::ALIGNMENT);
        let mmap_ptr = if mmap_len == file_len {
            // SAFETY: We are calling the system call with valid arguments and checking the
            // returned value.
            unsafe {
                let r = libc::mmap(
                    std::ptr::null_mut(),
                    u64_to_usize(file_len),
                    prot,
                    libc::MAP_SHARED | libc::MAP_NORESERVE,
                    file.as_raw_fd(),
                    0,
                );
                if r == libc::MAP_FAILED {
                    return Err(PmemError::BackingFile(std::io::Error::last_os_error()));
                }
                r
            }
        } else {
            // SAFETY: We are calling system calls with valid arguments and checking returned
            // values.
            //
            // The double mapping is done to ensure the underlying memory has the size of
            // `mmap_len` (which is 2MiB aligned as per the `virtio-pmem` specification).
            // The first mmap creates an anonymous mapping of `mmap_len`, while the second mmaps
            // the actual file on top of it. The remaining gap between the end of the mmaped file
            // and the actual end of the memory region stays backed by PRIVATE | ANONYMOUS memory.
            unsafe {
                let mmap_ptr = libc::mmap(
                    std::ptr::null_mut(),
                    u64_to_usize(mmap_len),
                    prot,
                    libc::MAP_PRIVATE | libc::MAP_NORESERVE | libc::MAP_ANONYMOUS,
                    -1,
                    0,
                );
                if mmap_ptr == libc::MAP_FAILED {
                    return Err(PmemError::BackingFile(std::io::Error::last_os_error()));
                }
                let r = libc::mmap(
                    mmap_ptr,
                    u64_to_usize(file_len),
                    prot,
                    libc::MAP_SHARED | libc::MAP_NORESERVE | libc::MAP_FIXED,
                    file.as_raw_fd(),
                    0,
                );
                if r == libc::MAP_FAILED {
                    return Err(PmemError::BackingFile(std::io::Error::last_os_error()));
                }
                mmap_ptr
            }
        };
        Ok((file, file_len, mmap_ptr as u64, mmap_len))
    }

    /// Allocate a guest-physical address range for the device from the past-mmio64 memory
    /// region and store its start address in the config space.
    pub fn alloc_region(&mut self, vm: &Vm) {
        let mut resource_allocator_lock = vm.resource_allocator();
        let resource_allocator = resource_allocator_lock.deref_mut();
        let addr = resource_allocator
            .past_mmio64_memory
            .allocate(
                self.config_space.size,
                Pmem::ALIGNMENT,
                AllocPolicy::FirstMatch,
            )
            .unwrap();
        self.config_space.start = addr.start();
    }

    /// Register the device's memory mapping as a user memory region with KVM.
    pub fn set_mem_region(&mut self, vm: &Vm) -> Result<(), PmemError> {
        let next_slot = vm.next_kvm_slot(1).ok_or(PmemError::NoKvmSlotAvailable)?;
        let memory_region = kvm_userspace_memory_region {
            slot: next_slot,
            guest_phys_addr: self.config_space.start,
            memory_size: self.config_space.size,
            userspace_addr: self.mmap_ptr,
            flags: if self.config.read_only {
                KVM_MEM_READONLY
            } else {
                0
            },
        };

        vm.set_user_memory_region(memory_region)
            .map_err(PmemError::SetUserMemoryRegion)
    }

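    /// Pop all available descriptor chains from the request queue, process each request and
    /// notify the guest via a queue interrupt if required.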
    pub fn handle_queue(&mut self) -> Result<(), PmemError> {
        // This is safe since we checked in the event handler that the device is activated.
        let active_state = self.device_state.active_state().unwrap();

        while let Some(head) = self.queues[0].pop()? {
            let add_result = match self.process_chain(head) {
                // A successful flush writes a 4-byte status back to the guest.
                Ok(()) => self.queues[0].add_used(head.index, 4),
                Err(err) => {
                    error!("pmem: {err}");
                    self.metrics.event_fails.inc();
                    self.queues[0].add_used(head.index, 0)
                }
            };
            if let Err(err) = add_result {
                error!("pmem: {err}");
                self.metrics.event_fails.inc();
                break;
            }
        }
        self.queues[0].advance_used_ring_idx();

        if self.queues[0].prepare_kick() {
            active_state
                .interrupt
                .trigger(VirtioInterruptType::Queue(0))
                .unwrap_or_else(|err| {
                    error!("pmem: {err}");
                    self.metrics.event_fails.inc();
                });
        }
        Ok(())
    }

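    /// Process a single request chain: the first (device-readable) descriptor holds the 4-byte
    /// request type and the second (device-writable) descriptor receives the 4-byte status.
    /// The only supported request is `VIRTIO_PMEM_REQ_TYPE_FLUSH`, which msyncs the backing file.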
    fn process_chain(&self, head: DescriptorChain) -> Result<(), PmemError> {
        // This is safe since we checked in the event handler that the device is activated.
        let active_state = self.device_state.active_state().unwrap();

        if head.is_write_only() {
            return Err(PmemError::WriteOnlyDescriptor);
        }
        let request: u32 = active_state.mem.read_obj(head.addr)?;
        if request != VIRTIO_PMEM_REQ_TYPE_FLUSH {
            return Err(PmemError::UnknownRequestType(request));
        }
        let Some(status_descriptor) = head.next_descriptor() else {
            return Err(PmemError::DescriptorChainTooShort);
        };
        if !status_descriptor.is_write_only() {
            return Err(PmemError::ReadOnlyDescriptor);
        }
        let mut result = SUCCESS;
        // SAFETY: We are calling the system call with valid arguments and checking the returned
        // value. Only the file-backed part of the mapping needs to be synced.
        unsafe {
            let ret = libc::msync(
                self.mmap_ptr as *mut libc::c_void,
                u64_to_usize(self.file_len),
                libc::MS_SYNC,
            );
            if ret < 0 {
                error!(
                    "pmem: Unable to msync the file. Error: {}",
                    std::io::Error::last_os_error()
                );
                result = FAILURE;
            }
        }
        active_state.mem.write_obj(result, status_descriptor.addr)?;
        Ok(())
    }

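    /// Handle a queue event: consume the queue eventfd and process the available requests.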
    pub fn process_queue(&mut self) {
        self.metrics.queue_event_count.inc();
        if let Err(err) = self.queue_events[0].read() {
            error!("pmem: Failed to get queue event: {err:?}");
            self.metrics.event_fails.inc();
            return;
        }

        self.handle_queue().unwrap_or_else(|err| {
            error!("pmem: {err:?}");
            self.metrics.event_fails.inc();
        });
    }
}

impl VirtioDevice for Pmem {
    impl_device_type!(VIRTIO_ID_PMEM);

    fn avail_features(&self) -> u64 {
        self.avail_features
    }

    fn acked_features(&self) -> u64 {
        self.acked_features
    }

    fn set_acked_features(&mut self, acked_features: u64) {
        self.acked_features = acked_features;
    }

    fn queues(&self) -> &[Queue] {
        &self.queues
    }

    fn queues_mut(&mut self) -> &mut [Queue] {
        &mut self.queues
    }

    fn queue_events(&self) -> &[EventFd] {
        &self.queue_events
    }

    fn interrupt_trigger(&self) -> &dyn VirtioInterrupt {
        self.device_state
            .active_state()
            .expect("Device not activated")
            .interrupt
            .deref()
    }

    fn read_config(&self, offset: u64, data: &mut [u8]) {
        if let Some(config_space_bytes) = self.config_space.as_slice().get(u64_to_usize(offset)..) {
            let len = config_space_bytes.len().min(data.len());
            data[..len].copy_from_slice(&config_space_bytes[..len]);
        } else {
            error!("Failed to read config space");
            self.metrics.cfg_fails.inc();
        }
    }

    fn write_config(&mut self, _offset: u64, _data: &[u8]) {}

    fn activate(
        &mut self,
        mem: GuestMemoryMmap,
        interrupt: Arc<dyn VirtioInterrupt>,
    ) -> Result<(), ActivateError> {
        for q in self.queues.iter_mut() {
            q.initialize(&mem)
                .map_err(ActivateError::QueueMemoryError)?;
        }

        if self.activate_event.write(1).is_err() {
            self.metrics.activate_fails.inc();
            return Err(ActivateError::EventFd);
        }
        self.device_state = DeviceState::Activated(ActiveState { mem, interrupt });
        Ok(())
    }

    fn is_activated(&self) -> bool {
        self.device_state.is_activated()
    }

    fn kick(&mut self) {
        if self.is_activated() {
            info!("kick pmem {}.", self.config.id);
            self.handle_queue().unwrap_or_else(|err| {
                error!("pmem: {err}");
                self.metrics.event_fails.inc();
            });
        }
    }
}

#[cfg(test)]
mod tests {
    use vmm_sys_util::tempfile::TempFile;

    use super::*;
    use crate::devices::virtio::queue::{VIRTQ_DESC_F_NEXT, VIRTQ_DESC_F_WRITE};
    use crate::devices::virtio::test_utils::{VirtQueue, default_interrupt, default_mem};

    #[test]
    fn test_from_config() {
        let config = PmemConfig {
            id: "1".into(),
            path_on_host: "not_a_path".into(),
            root_device: true,
            read_only: false,
        };
        assert!(matches!(
            Pmem::new(config).unwrap_err(),
            PmemError::BackingFile(_),
        ));

        let dummy_file = TempFile::new().unwrap();
        let dummy_path = dummy_file.as_path().to_str().unwrap().to_string();
        let config = PmemConfig {
            id: "1".into(),
            path_on_host: dummy_path.clone(),
            root_device: true,
            read_only: false,
        };
        assert!(matches!(
            Pmem::new(config).unwrap_err(),
            PmemError::BackingFileZeroSize,
        ));

        dummy_file.as_file().set_len(0x20_0000).unwrap();
        let config = PmemConfig {
            id: "1".into(),
            path_on_host: dummy_path,
            root_device: true,
            read_only: false,
        };
        Pmem::new(config).unwrap();
    }

    #[test]
    fn test_process_chain() {
        let dummy_file = TempFile::new().unwrap();
        dummy_file.as_file().set_len(0x20_0000).unwrap();
        let dummy_path = dummy_file.as_path().to_str().unwrap().to_string();
        let config = PmemConfig {
            id: "1".into(),
            path_on_host: dummy_path,
            root_device: true,
            read_only: false,
        };
        let mut pmem = Pmem::new(config).unwrap();

        let mem = default_mem();
        let interrupt = default_interrupt();
        let vq = VirtQueue::new(GuestAddress(0), &mem, 16);
        pmem.queues[0] = vq.create_queue();
        pmem.activate(mem.clone(), interrupt).unwrap();

        // Valid request
        {
            vq.avail.ring[0].set(0);
            vq.dtable[0].set(0x1000, 4, VIRTQ_DESC_F_NEXT, 1);
            vq.avail.ring[1].set(1);
            vq.dtable[1].set(0x2000, 4, VIRTQ_DESC_F_WRITE, 0);
            mem.write_obj::<u32>(0, GuestAddress(0x1000)).unwrap();
            // Pre-fill the status word with a garbage value to verify it gets overwritten
            // with SUCCESS (0).
            mem.write_obj::<u32>(0x69, GuestAddress(0x2000)).unwrap();

            vq.used.idx.set(0);
            vq.avail.idx.set(1);
            let head = pmem.queues[0].pop().unwrap().unwrap();
            pmem.process_chain(head).unwrap();
            assert_eq!(mem.read_obj::<u32>(GuestAddress(0x2000)).unwrap(), 0);
        }

        // Invalid request type
        {
            vq.avail.ring[0].set(0);
            vq.dtable[0].set(0x1000, 4, VIRTQ_DESC_F_NEXT, 1);
            mem.write_obj::<u32>(0x69, GuestAddress(0x1000)).unwrap();

            pmem.queues[0] = vq.create_queue();
            vq.used.idx.set(0);
            vq.avail.idx.set(1);
            let head = pmem.queues[0].pop().unwrap().unwrap();
            assert!(matches!(
                pmem.process_chain(head).unwrap_err(),
                PmemError::UnknownRequestType(0x69),
            ));
        }

        // Short chain request
        {
            vq.avail.ring[0].set(0);
            vq.dtable[0].set(0x1000, 4, 0, 1);
            mem.write_obj::<u32>(0, GuestAddress(0x1000)).unwrap();

            pmem.queues[0] = vq.create_queue();
            vq.used.idx.set(0);
            vq.avail.idx.set(1);
            let head = pmem.queues[0].pop().unwrap().unwrap();
            assert!(matches!(
                pmem.process_chain(head).unwrap_err(),
                PmemError::DescriptorChainTooShort,
            ));
        }

        // Write only first descriptor
        {
            vq.avail.ring[0].set(0);
            vq.dtable[0].set(0x1000, 4, VIRTQ_DESC_F_WRITE | VIRTQ_DESC_F_NEXT, 1);
            vq.avail.ring[1].set(1);
            vq.dtable[1].set(0x2000, 4, VIRTQ_DESC_F_WRITE, 0);
            mem.write_obj::<u32>(0, GuestAddress(0x1000)).unwrap();

            pmem.queues[0] = vq.create_queue();
            vq.used.idx.set(0);
            vq.avail.idx.set(1);
            let head = pmem.queues[0].pop().unwrap().unwrap();
            assert!(matches!(
                pmem.process_chain(head).unwrap_err(),
                PmemError::WriteOnlyDescriptor,
            ));
        }

        // Read only second descriptor
        {
            vq.avail.ring[0].set(0);
            vq.dtable[0].set(0x1000, 4, VIRTQ_DESC_F_NEXT, 1);
            vq.avail.ring[1].set(1);
            vq.dtable[1].set(0x2000, 4, 0, 0);
            mem.write_obj::<u32>(0, GuestAddress(0x1000)).unwrap();

            pmem.queues[0] = vq.create_queue();
            vq.used.idx.set(0);
            vq.avail.idx.set(1);
            let head = pmem.queues[0].pop().unwrap().unwrap();
            assert!(matches!(
                pmem.process_chain(head).unwrap_err(),
                PmemError::ReadOnlyDescriptor,
            ));
        }
    }
}
557}