// vmm/devices/virtio/persist.rs

// Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

//! Defines the structures needed for saving/restoring Virtio primitives.

6use std::num::Wrapping;
7use std::sync::atomic::Ordering;
8use std::sync::{Arc, Mutex};
9
10use serde::{Deserialize, Serialize};
11
12use super::queue::{InvalidAvailIdx, QueueError};
13use super::transport::mmio::IrqTrigger;
14use crate::devices::virtio::device::VirtioDevice;
15use crate::devices::virtio::generated::virtio_ring::VIRTIO_RING_F_EVENT_IDX;
16use crate::devices::virtio::queue::Queue;
17use crate::devices::virtio::transport::mmio::MmioTransport;
18use crate::snapshot::Persist;
19use crate::vstate::memory::{GuestAddress, GuestMemoryMmap};
20
/// Errors thrown during restoring virtio state.
// NOTE: this enum derives `displaydoc::Display`, so each variant's `///` line below is also that
// variant's user-visible `Display` message (`{0}` interpolates the wrapped error). Treat those
// lines as runtime strings, not as ordinary documentation.
#[derive(Debug, thiserror::Error, displaydoc::Display)]
pub enum PersistError {
    // Returned by `VirtioDeviceState::build_queues_checked` when the saved device type, acked
    // feature set, queue count or a queue's max size does not match what the caller expects.
    /// Snapshot state contains invalid queue info.
    InvalidInput,
    // Wraps the `QueueError` returned by `Queue::restore` (which can fail in `Queue::initialize`
    // when the owning device was activated).
    /// Could not restore queue: {0}
    QueueConstruction(QueueError),
    // `#[from]` lets `?` convert an `InvalidAvailIdx` into this variant.
    /// {0}
    InvalidAvailIdx(#[from] InvalidAvailIdx),
}

/// Queue information saved in snapshot.
///
/// Produced by `Persist::save` for `Queue` and consumed by `Persist::restore`. The queue's raw
/// guest-memory pointers and its notification-suppression flag are not part of this state;
/// `restore` leaves the pointers null unless the owning device was activated, in which case it
/// calls `Queue::initialize`.
// NOTE(review): this struct is serialized into snapshots. If the snapshot encoding is positional
// (e.g. bincode), field order and types are part of the on-disk format — do not reorder or
// retype fields without a snapshot format/version change. Confirm against `crate::snapshot`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct QueueState {
    /// The maximal size in elements offered by the device
    max_size: u16,

    /// The queue size in elements the driver selected
    size: u16,

    /// Indicates if the queue is finished with configuration
    ready: bool,

    /// Guest physical address of the descriptor table
    desc_table: u64,

    /// Guest physical address of the available ring
    avail_ring: u64,

    /// Guest physical address of the used ring
    used_ring: u64,

    /// Saved copy of `Queue::next_avail` (presumably the next available-ring index the device
    /// will consume — confirm in `queue.rs`).
    next_avail: Wrapping<u16>,
    /// Saved copy of `Queue::next_used` (presumably the next used-ring index the device will
    /// fill — confirm in `queue.rs`).
    next_used: Wrapping<u16>,

    /// The number of added used buffers since last guest kick
    num_added: Wrapping<u16>,
}

/// Auxiliary structure for restoring queues.
///
/// Passed by value to `Queue::restore`; `VirtioDeviceState::build_queues_checked` clones one
/// instance per restored queue.
#[derive(Debug, Clone)]
pub struct QueueConstructorArgs {
    /// Guest memory the queue is initialized against when `is_activated` is set.
    pub mem: GuestMemoryMmap,
    /// Whether the device this queue belongs to was activated when its state was saved.
    /// When `true`, `Queue::restore` calls `Queue::initialize` with `mem`.
    pub is_activated: bool,
}

69impl Persist<'_> for Queue {
70    type State = QueueState;
71    type ConstructorArgs = QueueConstructorArgs;
72    type Error = QueueError;
73
74    fn save(&self) -> Self::State {
75        QueueState {
76            max_size: self.max_size,
77            size: self.size,
78            ready: self.ready,
79            desc_table: self.desc_table_address.0,
80            avail_ring: self.avail_ring_address.0,
81            used_ring: self.used_ring_address.0,
82            next_avail: self.next_avail,
83            next_used: self.next_used,
84            num_added: self.num_added,
85        }
86    }
87
88    fn restore(
89        constructor_args: Self::ConstructorArgs,
90        state: &Self::State,
91    ) -> Result<Self, Self::Error> {
92        let mut queue = Queue {
93            max_size: state.max_size,
94            size: state.size,
95            ready: state.ready,
96            desc_table_address: GuestAddress(state.desc_table),
97            avail_ring_address: GuestAddress(state.avail_ring),
98            used_ring_address: GuestAddress(state.used_ring),
99
100            desc_table_ptr: std::ptr::null(),
101            avail_ring_ptr: std::ptr::null_mut(),
102            used_ring_ptr: std::ptr::null_mut(),
103
104            next_avail: state.next_avail,
105            next_used: state.next_used,
106            uses_notif_suppression: false,
107            num_added: state.num_added,
108        };
109        if constructor_args.is_activated {
110            queue.initialize(&constructor_args.mem)?;
111        }
112        Ok(queue)
113    }
114}
115
/// State of a VirtioDevice.
///
/// Captured by `VirtioDeviceState::from_device` and turned back into queues by
/// `VirtioDeviceState::build_queues_checked`.
// NOTE(review): serialized into snapshots; if the encoding is positional, field order and types
// are part of the snapshot format — confirm against `crate::snapshot` before changing them.
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct VirtioDeviceState {
    /// Device type.
    pub device_type: u32,
    /// Available virtio features.
    pub avail_features: u64,
    /// Negotiated virtio features; `build_queues_checked` rejects a state in which this is not
    /// a subset of `avail_features`.
    pub acked_features: u64,
    /// Interrupt status register.
    pub interrupt_status: u32,
    /// List of queues.
    pub queues: Vec<QueueState>,
    /// Whether the device was activated; when set, restored queues are initialized against
    /// guest memory.
    pub activated: bool,
}

133impl VirtioDeviceState {
134    /// Construct the virtio state of a device.
135    pub fn from_device(device: &dyn VirtioDevice) -> Self {
136        VirtioDeviceState {
137            device_type: device.device_type(),
138            avail_features: device.avail_features(),
139            acked_features: device.acked_features(),
140            interrupt_status: device.interrupt_status().load(Ordering::Relaxed),
141            queues: device.queues().iter().map(Persist::save).collect(),
142            activated: device.is_activated(),
143        }
144    }
145
146    /// Does sanity checking on the `self` state against expected values
147    /// and builds queues from state.
148    pub fn build_queues_checked(
149        &self,
150        mem: &GuestMemoryMmap,
151        expected_device_type: u32,
152        expected_num_queues: usize,
153        expected_queue_max_size: u16,
154    ) -> Result<Vec<Queue>, PersistError> {
155        // Sanity check:
156        // - right device type,
157        // - acked features is a subset of available ones,
158        // - right number of queues,
159        if self.device_type != expected_device_type
160            || (self.acked_features & !self.avail_features) != 0
161            || self.queues.len() != expected_num_queues
162        {
163            return Err(PersistError::InvalidInput);
164        }
165
166        let uses_notif_suppression = (self.acked_features & (1u64 << VIRTIO_RING_F_EVENT_IDX)) != 0;
167        let queue_construction_args = QueueConstructorArgs {
168            mem: mem.clone(),
169            is_activated: self.activated,
170        };
171        let queues: Vec<Queue> = self
172            .queues
173            .iter()
174            .map(|queue_state| {
175                Queue::restore(queue_construction_args.clone(), queue_state)
176                    .map(|mut queue| {
177                        if uses_notif_suppression {
178                            queue.enable_notif_suppression();
179                        }
180                        queue
181                    })
182                    .map_err(PersistError::QueueConstruction)
183            })
184            .collect::<Result<_, _>>()?;
185
186        for q in &queues {
187            // Sanity check queue size and queue max size.
188            if q.max_size != expected_queue_max_size {
189                return Err(PersistError::InvalidInput);
190            }
191        }
192        Ok(queues)
193    }
194}
195
/// Transport information saved in snapshot.
///
/// Every field except `interrupt_status` is a saved copy of the same-named `MmioTransport`
/// field; `interrupt_status` is the value of the transport's `interrupt.irq_status` atomic.
// NOTE(review): serialized into snapshots; if the encoding is positional, field order and types
// are part of the snapshot format — confirm against `crate::snapshot` before changing them.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct MmioTransportState {
    /// Saved `MmioTransport::features_select`. Presumably the virtio-mmio DeviceFeaturesSel
    /// value, i.e. which 32-bit word of the device features the driver reads — confirm in
    /// `mmio.rs`.
    features_select: u32,
    /// Saved `MmioTransport::acked_features_select`. Presumably the virtio-mmio
    /// DriverFeaturesSel value, i.e. which 32-bit word of acked features the driver writes —
    /// confirm in `mmio.rs`.
    acked_features_select: u32,
    /// Saved `MmioTransport::queue_select` (presumably the queue targeted by per-queue
    /// registers).
    queue_select: u32,
    /// Saved `MmioTransport::device_status`.
    device_status: u32,
    /// Saved `MmioTransport::config_generation`.
    config_generation: u32,
    /// Value of `MmioTransport::interrupt.irq_status` at save time.
    interrupt_status: u32,
}

209impl MmioTransportState {
210    /// Apply transport state to an existing MMIO transport (NYX extension).
211    pub fn apply_to(&self, transport: &mut MmioTransport) {
212        transport.features_select = self.features_select;
213        transport.acked_features_select = self.acked_features_select;
214        transport.queue_select = self.queue_select;
215        transport.device_status = self.device_status;
216        transport.config_generation = self.config_generation;
217        transport
218            .interrupt
219            .irq_status
220            .store(self.interrupt_status, Ordering::Relaxed);
221    }
222}
223
/// Auxiliary structure for initializing the transport when resuming from a snapshot.
///
/// `MmioTransport::restore` passes every field unchanged to `MmioTransport::new`.
#[derive(Debug)]
pub struct MmioTransportConstructorArgs {
    /// Guest memory handed to the new transport.
    pub mem: GuestMemoryMmap,
    /// Interrupt to use for the device
    pub interrupt: Arc<IrqTrigger>,
    /// Device associated with the current MMIO state.
    pub device: Arc<Mutex<dyn VirtioDevice>>,
    /// Whether the device is backed by vhost-user.
    pub is_vhost_user: bool,
}

237impl Persist<'_> for MmioTransport {
238    type State = MmioTransportState;
239    type ConstructorArgs = MmioTransportConstructorArgs;
240    type Error = ();
241
242    fn save(&self) -> Self::State {
243        MmioTransportState {
244            features_select: self.features_select,
245            acked_features_select: self.acked_features_select,
246            queue_select: self.queue_select,
247            device_status: self.device_status,
248            config_generation: self.config_generation,
249            interrupt_status: self.interrupt.irq_status.load(Ordering::SeqCst),
250        }
251    }
252
253    fn restore(
254        constructor_args: Self::ConstructorArgs,
255        state: &Self::State,
256    ) -> Result<Self, Self::Error> {
257        let mut transport = MmioTransport::new(
258            constructor_args.mem,
259            constructor_args.interrupt,
260            constructor_args.device,
261            constructor_args.is_vhost_user,
262        );
263        transport.features_select = state.features_select;
264        transport.acked_features_select = state.acked_features_select;
265        transport.queue_select = state.queue_select;
266        transport.device_status = state.device_status;
267        transport.config_generation = state.config_generation;
268        transport
269            .interrupt
270            .irq_status
271            .store(state.interrupt_status, Ordering::SeqCst);
272        Ok(transport)
273    }
274}
275
#[cfg(test)]
mod tests {
    use vmm_sys_util::tempfile::TempFile;

    use super::*;
    use crate::devices::virtio::block::virtio::VirtioBlock;
    use crate::devices::virtio::block::virtio::device::FileEngineType;
    use crate::devices::virtio::block::virtio::test_utils::default_block_with_path;
    use crate::devices::virtio::net::Net;
    use crate::devices::virtio::net::test_utils::default_net;
    use crate::devices::virtio::test_utils::default_mem;
    use crate::devices::virtio::transport::mmio::tests::DummyDevice;
    use crate::devices::virtio::vsock::{Vsock, VsockUnixBackend};
    use crate::snapshot::Snapshot;

    /// Queue max size used by the hand-built `QueueState`s below.
    const DEFAULT_QUEUE_MAX_SIZE: u16 = 256;

    /// Length of the scratch buffer each snapshot is serialized into.
    const SNAPSHOT_BUF_LEN: usize = 4096;

    /// A fully sized queue state that is not ready and whose rings all sit at guest address 0.
    impl Default for QueueState {
        fn default() -> Self {
            Self {
                max_size: DEFAULT_QUEUE_MAX_SIZE,
                size: DEFAULT_QUEUE_MAX_SIZE,
                ready: false,
                desc_table: 0,
                avail_ring: 0,
                used_ring: 0,
                next_avail: Wrapping(0),
                next_used: Wrapping(0),
                num_added: Wrapping(0),
            }
        }
    }

    #[test]
    fn test_virtiodev_sanity_checks() {
        let max_size = DEFAULT_QUEUE_MAX_SIZE;
        let mem = default_mem();

        // Device-level checks, on a state that has no queues.
        let mut state = VirtioDeviceState::default();
        // Everything matches.
        state.build_queues_checked(&mem, 0, 0, max_size).unwrap();
        // Wrong device type.
        state
            .build_queues_checked(&mem, 1, 0, max_size)
            .unwrap_err();
        // Wrong number of queues.
        state
            .build_queues_checked(&mem, 0, 1, max_size)
            .unwrap_err();
        // A feature bit is acked that was never made available.
        state.acked_features = 1;
        state
            .build_queues_checked(&mem, 0, 0, max_size)
            .unwrap_err();

        // Queue-level checks: each case installs one queue and expects exactly that many back.
        let build = |candidate: &VirtioDeviceState| {
            candidate.build_queues_checked(&mem, 0, candidate.queues.len(), max_size)
        };
        let mut state = VirtioDeviceState::default();

        // Valid queue on a non-activated device.
        state.queues = vec![QueueState::default()];
        build(&state).unwrap();

        // Max queue size differs from the expected one.
        state.queues = vec![QueueState {
            max_size: max_size + 1,
            ..Default::default()
        }];
        build(&state).unwrap_err();

        // Queue size larger than the max, on an activated device.
        state.queues = vec![QueueState {
            size: max_size + 1,
            ..Default::default()
        }];
        state.activated = true;
        build(&state).unwrap_err();

        // Activated device whose queue is not usable (the default state is not ready).
        state.queues = vec![QueueState::default()];
        state.activated = true;
        build(&state).unwrap_err();
    }

    #[test]
    fn test_queue_persistence() {
        let mem = default_mem();

        let mut queue = Queue::new(128);
        queue.ready = true;
        queue.size = queue.max_size;
        queue.initialize(&mem).unwrap();

        let mut snapshot_buf = vec![0; SNAPSHOT_BUF_LEN];
        Snapshot::new(queue.save())
            .save(&mut snapshot_buf.as_mut_slice())
            .unwrap();

        let saved: QueueState = Snapshot::load_without_crc_check(snapshot_buf.as_slice())
            .unwrap()
            .data;
        let restored = Queue::restore(
            QueueConstructorArgs {
                mem,
                is_activated: true,
            },
            &saved,
        )
        .unwrap();

        assert_eq!(restored, queue);
    }

    #[test]
    fn test_virtio_device_state_serde() {
        let device = DummyDevice::new();
        let mut snapshot_buf = vec![0; SNAPSHOT_BUF_LEN];

        let saved = VirtioDeviceState::from_device(&device);
        Snapshot::new(&saved)
            .save(&mut snapshot_buf.as_mut_slice())
            .unwrap();

        let loaded: VirtioDeviceState = Snapshot::load_without_crc_check(snapshot_buf.as_slice())
            .unwrap()
            .data;
        assert_eq!(loaded, saved);
    }

    impl PartialEq for MmioTransport {
        fn eq(&self, other: &MmioTransport) -> bool {
            // `self`'s device is locked in its own statement so the guard is dropped before
            // `other`'s device is locked below. The persistence tests compare transports that
            // share one device `Arc`, and re-locking a held `std::sync::Mutex` on the same
            // thread would deadlock or panic.
            let self_dev_type = self.device().lock().unwrap().device_type();
            let irq_status =
                |transport: &MmioTransport| transport.interrupt.irq_status.load(Ordering::SeqCst);

            let registers_equal = self.acked_features_select == other.acked_features_select
                && self.features_select == other.features_select
                && self.queue_select == other.queue_select
                && self.device_status == other.device_status
                && self.config_generation == other.config_generation;

            // Only the device type is compared; each device's own tests cover its (de)ser.
            registers_equal
                && irq_status(self) == irq_status(other)
                && self_dev_type == other.device().lock().unwrap().device_type()
        }
    }

    /// Round-trips `mmio_transport` through a snapshot and checks the restored transport
    /// compares equal to the original.
    fn generic_mmiotransport_persistence_test(
        mmio_transport: MmioTransport,
        interrupt: Arc<IrqTrigger>,
        mem: GuestMemoryMmap,
        device: Arc<Mutex<dyn VirtioDevice>>,
    ) {
        let mut snapshot_buf = vec![0; SNAPSHOT_BUF_LEN];
        Snapshot::new(mmio_transport.save())
            .save(&mut snapshot_buf.as_mut_slice())
            .unwrap();

        let saved: MmioTransportState = Snapshot::load_without_crc_check(snapshot_buf.as_slice())
            .unwrap()
            .data;
        let restored = MmioTransport::restore(
            MmioTransportConstructorArgs {
                mem,
                interrupt,
                device,
                is_vhost_user: false,
            },
            &saved,
        )
        .unwrap();

        assert_eq!(restored, mmio_transport);
    }

    /// Builds a block device over a one-page temporary backing file, wrapped in a transport.
    fn create_default_block() -> (
        MmioTransport,
        Arc<IrqTrigger>,
        GuestMemoryMmap,
        Arc<Mutex<VirtioBlock>>,
    ) {
        let mem = default_mem();
        let interrupt = Arc::new(IrqTrigger::new());

        let backing_file = TempFile::new().unwrap();
        backing_file.as_file().set_len(0x1000).unwrap();
        let backing_path = backing_file.as_path().to_str().unwrap().to_string();
        let block = Arc::new(Mutex::new(default_block_with_path(
            backing_path,
            FileEngineType::default(),
        )));

        let transport = MmioTransport::new(mem.clone(), interrupt.clone(), block.clone(), false);
        (transport, interrupt, mem, block)
    }

    /// Builds the default net device wrapped in a transport.
    fn create_default_net() -> (
        MmioTransport,
        Arc<IrqTrigger>,
        GuestMemoryMmap,
        Arc<Mutex<Net>>,
    ) {
        let mem = default_mem();
        let interrupt = Arc::new(IrqTrigger::new());
        let net = Arc::new(Mutex::new(default_net()));

        let transport = MmioTransport::new(mem.clone(), interrupt.clone(), net.clone(), false);
        (transport, interrupt, mem, net)
    }

    /// Builds a vsock device with a Unix-socket backend wrapped in a transport.
    fn default_vsock() -> (
        MmioTransport,
        Arc<IrqTrigger>,
        GuestMemoryMmap,
        Arc<Mutex<Vsock<VsockUnixBackend>>>,
    ) {
        let mem = default_mem();
        let interrupt = Arc::new(IrqTrigger::new());

        let guest_cid = 52;
        // Reserve a unique path, then delete the file so the backend can bind a socket there.
        let mut uds_file = TempFile::new().unwrap();
        uds_file.remove().unwrap();
        let uds_path = uds_file.as_path().to_str().unwrap().to_owned();
        let backend = VsockUnixBackend::new(guest_cid, uds_path).unwrap();
        let vsock = Arc::new(Mutex::new(Vsock::new(guest_cid, backend).unwrap()));

        let transport = MmioTransport::new(mem.clone(), interrupt.clone(), vsock.clone(), false);
        (transport, interrupt, mem, vsock)
    }

    #[test]
    fn test_block_over_mmiotransport_persistence() {
        let (transport, interrupt, mem, device) = create_default_block();
        generic_mmiotransport_persistence_test(transport, interrupt, mem, device);
    }

    #[test]
    fn test_net_over_mmiotransport_persistence() {
        let (transport, interrupt, mem, device) = create_default_net();
        generic_mmiotransport_persistence_test(transport, interrupt, mem, device);
    }

    #[test]
    fn test_vsock_over_mmiotransport_persistence() {
        let (transport, interrupt, mem, device) = default_vsock();
        generic_mmiotransport_persistence_test(transport, interrupt, mem, device);
    }
}