1use std::num::Wrapping;
7use std::sync::atomic::Ordering;
8use std::sync::{Arc, Mutex};
9
10use serde::{Deserialize, Serialize};
11
12use super::queue::{InvalidAvailIdx, QueueError};
13use super::transport::mmio::IrqTrigger;
14use crate::devices::virtio::device::VirtioDevice;
15use crate::devices::virtio::generated::virtio_ring::VIRTIO_RING_F_EVENT_IDX;
16use crate::devices::virtio::queue::Queue;
17use crate::devices::virtio::transport::mmio::MmioTransport;
18use crate::snapshot::Persist;
19use crate::vstate::memory::{GuestAddress, GuestMemoryMmap};
20
/// Errors thrown during restoring virtio state.
#[derive(Debug, thiserror::Error, displaydoc::Display)]
pub enum PersistError {
    /// Snapshot state contains invalid queue info
    InvalidInput,
    /// Could not restore queue: {0}
    QueueConstruction(QueueError),
    /// {0}
    InvalidAvailIdx(#[from] InvalidAvailIdx),
}
31
/// State of a virtio queue, as persisted in a snapshot.
/// Mirrors the persisted fields of [`Queue`]; raw host ring pointers are
/// deliberately not part of this state (they are recomputed on restore).
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct QueueState {
    /// Maximum queue size supported by the device.
    max_size: u16,

    /// Queue size currently configured (mirrors `Queue::size`).
    size: u16,

    /// Whether the queue was marked ready (mirrors `Queue::ready`).
    ready: bool,

    /// Guest physical address of the descriptor table.
    desc_table: u64,

    /// Guest physical address of the available ring.
    avail_ring: u64,

    /// Guest physical address of the used ring.
    used_ring: u64,

    /// Next index to process from the available ring.
    next_avail: Wrapping<u16>,
    /// Next index to write into the used ring.
    next_used: Wrapping<u16>,

    /// Mirrors `Queue::num_added` — presumably the count of buffers added to
    /// the used ring since the last notification (EVENT_IDX); confirm against
    /// the `Queue` implementation.
    num_added: Wrapping<u16>,
}
59
/// Arguments needed to restore a [`Queue`] from its saved [`QueueState`].
#[derive(Debug, Clone)]
pub struct QueueConstructorArgs {
    /// Guest memory in which the queue's rings live.
    pub mem: GuestMemoryMmap,
    /// Whether the owning device was activated at snapshot time; only then is
    /// the restored queue (re-)initialized against guest memory.
    pub is_activated: bool,
}
68
69impl Persist<'_> for Queue {
70 type State = QueueState;
71 type ConstructorArgs = QueueConstructorArgs;
72 type Error = QueueError;
73
74 fn save(&self) -> Self::State {
75 QueueState {
76 max_size: self.max_size,
77 size: self.size,
78 ready: self.ready,
79 desc_table: self.desc_table_address.0,
80 avail_ring: self.avail_ring_address.0,
81 used_ring: self.used_ring_address.0,
82 next_avail: self.next_avail,
83 next_used: self.next_used,
84 num_added: self.num_added,
85 }
86 }
87
88 fn restore(
89 constructor_args: Self::ConstructorArgs,
90 state: &Self::State,
91 ) -> Result<Self, Self::Error> {
92 let mut queue = Queue {
93 max_size: state.max_size,
94 size: state.size,
95 ready: state.ready,
96 desc_table_address: GuestAddress(state.desc_table),
97 avail_ring_address: GuestAddress(state.avail_ring),
98 used_ring_address: GuestAddress(state.used_ring),
99
100 desc_table_ptr: std::ptr::null(),
101 avail_ring_ptr: std::ptr::null_mut(),
102 used_ring_ptr: std::ptr::null_mut(),
103
104 next_avail: state.next_avail,
105 next_used: state.next_used,
106 uses_notif_suppression: false,
107 num_added: state.num_added,
108 };
109 if constructor_args.is_activated {
110 queue.initialize(&constructor_args.mem)?;
111 }
112 Ok(queue)
113 }
114}
115
/// Transport-independent virtio device state, common to all device types.
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct VirtioDeviceState {
    /// Virtio device type identifier.
    pub device_type: u32,
    /// Feature set offered by the device.
    pub avail_features: u64,
    /// Feature set acknowledged by the guest driver.
    pub acked_features: u64,
    /// Interrupt status bits at snapshot time.
    pub interrupt_status: u32,
    /// Saved state of each of the device's queues.
    pub queues: Vec<QueueState>,
    /// Whether the device was activated at snapshot time.
    pub activated: bool,
}
132
133impl VirtioDeviceState {
134 pub fn from_device(device: &dyn VirtioDevice) -> Self {
136 VirtioDeviceState {
137 device_type: device.device_type(),
138 avail_features: device.avail_features(),
139 acked_features: device.acked_features(),
140 interrupt_status: device.interrupt_status().load(Ordering::Relaxed),
141 queues: device.queues().iter().map(Persist::save).collect(),
142 activated: device.is_activated(),
143 }
144 }
145
146 pub fn build_queues_checked(
149 &self,
150 mem: &GuestMemoryMmap,
151 expected_device_type: u32,
152 expected_num_queues: usize,
153 expected_queue_max_size: u16,
154 ) -> Result<Vec<Queue>, PersistError> {
155 if self.device_type != expected_device_type
160 || (self.acked_features & !self.avail_features) != 0
161 || self.queues.len() != expected_num_queues
162 {
163 return Err(PersistError::InvalidInput);
164 }
165
166 let uses_notif_suppression = (self.acked_features & (1u64 << VIRTIO_RING_F_EVENT_IDX)) != 0;
167 let queue_construction_args = QueueConstructorArgs {
168 mem: mem.clone(),
169 is_activated: self.activated,
170 };
171 let queues: Vec<Queue> = self
172 .queues
173 .iter()
174 .map(|queue_state| {
175 Queue::restore(queue_construction_args.clone(), queue_state)
176 .map(|mut queue| {
177 if uses_notif_suppression {
178 queue.enable_notif_suppression();
179 }
180 queue
181 })
182 .map_err(PersistError::QueueConstruction)
183 })
184 .collect::<Result<_, _>>()?;
185
186 for q in &queues {
187 if q.max_size != expected_queue_max_size {
189 return Err(PersistError::InvalidInput);
190 }
191 }
192 Ok(queues)
193 }
194}
195
/// State of the MMIO transport's register bank, as persisted in a snapshot.
/// Mirrors the corresponding fields of `MmioTransport`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct MmioTransportState {
    /// Selector for the device-features window.
    features_select: u32,
    /// Selector for the driver (acked) features window.
    acked_features_select: u32,
    /// Currently selected queue index.
    queue_select: u32,
    /// Virtio device status register.
    device_status: u32,
    /// Config-space generation counter.
    config_generation: u32,
    /// Interrupt status bits at snapshot time.
    interrupt_status: u32,
}
208
209impl MmioTransportState {
210 pub fn apply_to(&self, transport: &mut MmioTransport) {
212 transport.features_select = self.features_select;
213 transport.acked_features_select = self.acked_features_select;
214 transport.queue_select = self.queue_select;
215 transport.device_status = self.device_status;
216 transport.config_generation = self.config_generation;
217 transport
218 .interrupt
219 .irq_status
220 .store(self.interrupt_status, Ordering::Relaxed);
221 }
222}
223
/// Arguments needed to restore an [`MmioTransport`] from its saved state.
#[derive(Debug)]
pub struct MmioTransportConstructorArgs {
    /// Guest memory the transport operates on.
    pub mem: GuestMemoryMmap,
    /// Interrupt line shared with the wrapped device.
    pub interrupt: Arc<IrqTrigger>,
    /// The virtio device sitting behind this transport.
    pub device: Arc<Mutex<dyn VirtioDevice>>,
    /// Whether the wrapped device is a vhost-user device.
    pub is_vhost_user: bool,
}
236
237impl Persist<'_> for MmioTransport {
238 type State = MmioTransportState;
239 type ConstructorArgs = MmioTransportConstructorArgs;
240 type Error = ();
241
242 fn save(&self) -> Self::State {
243 MmioTransportState {
244 features_select: self.features_select,
245 acked_features_select: self.acked_features_select,
246 queue_select: self.queue_select,
247 device_status: self.device_status,
248 config_generation: self.config_generation,
249 interrupt_status: self.interrupt.irq_status.load(Ordering::SeqCst),
250 }
251 }
252
253 fn restore(
254 constructor_args: Self::ConstructorArgs,
255 state: &Self::State,
256 ) -> Result<Self, Self::Error> {
257 let mut transport = MmioTransport::new(
258 constructor_args.mem,
259 constructor_args.interrupt,
260 constructor_args.device,
261 constructor_args.is_vhost_user,
262 );
263 transport.features_select = state.features_select;
264 transport.acked_features_select = state.acked_features_select;
265 transport.queue_select = state.queue_select;
266 transport.device_status = state.device_status;
267 transport.config_generation = state.config_generation;
268 transport
269 .interrupt
270 .irq_status
271 .store(state.interrupt_status, Ordering::SeqCst);
272 Ok(transport)
273 }
274}
275
#[cfg(test)]
mod tests {
    use vmm_sys_util::tempfile::TempFile;

    use super::*;
    use crate::devices::virtio::block::virtio::VirtioBlock;
    use crate::devices::virtio::block::virtio::device::FileEngineType;
    use crate::devices::virtio::block::virtio::test_utils::default_block_with_path;
    use crate::devices::virtio::net::Net;
    use crate::devices::virtio::net::test_utils::default_net;
    use crate::devices::virtio::test_utils::default_mem;
    use crate::devices::virtio::transport::mmio::tests::DummyDevice;
    use crate::devices::virtio::vsock::{Vsock, VsockUnixBackend};
    use crate::snapshot::Snapshot;

    const DEFAULT_QUEUE_MAX_SIZE: u16 = 256;
    // Test-only default so `QueueState` values can be built tersely with
    // struct-update syntax in the sanity checks below.
    impl Default for QueueState {
        fn default() -> QueueState {
            QueueState {
                max_size: DEFAULT_QUEUE_MAX_SIZE,
                size: DEFAULT_QUEUE_MAX_SIZE,
                ready: false,
                desc_table: 0,
                avail_ring: 0,
                used_ring: 0,
                next_avail: Wrapping(0),
                next_used: Wrapping(0),
                num_added: Wrapping(0),
            }
        }
    }

    // Exercises the validation paths of `build_queues_checked`.
    #[test]
    fn test_virtiodev_sanity_checks() {
        let max_size = DEFAULT_QUEUE_MAX_SIZE;
        let mut state = VirtioDeviceState::default();
        let mem = default_mem();
        // Matching device type and (empty) queue list: succeeds.
        state.build_queues_checked(&mem, 0, 0, max_size).unwrap();
        // Device-type mismatch is rejected.
        state
            .build_queues_checked(&mem, 1, 0, max_size)
            .unwrap_err();
        // Queue-count mismatch is rejected.
        state
            .build_queues_checked(&mem, 0, 1, max_size)
            .unwrap_err();
        // Acked features that were never offered are rejected.
        state.acked_features = 1;
        state
            .build_queues_checked(&mem, 0, 0, max_size)
            .unwrap_err();

        let mut state = VirtioDeviceState::default();
        // A well-formed (default, non-activated) queue restores fine.
        let good_q = QueueState::default();
        state.queues = vec![good_q];
        state
            .build_queues_checked(&mem, 0, state.queues.len(), max_size)
            .unwrap();

        // A queue whose max_size exceeds the expected maximum is rejected.
        let bad_q = QueueState {
            max_size: max_size + 1,
            ..Default::default()
        };
        state.queues = vec![bad_q];
        state
            .build_queues_checked(&mem, 0, state.queues.len(), max_size)
            .unwrap_err();

        // size > max_size on an activated device: restore fails (the queue is
        // initialized against guest memory and rejected).
        let bad_q = QueueState {
            size: max_size + 1,
            ..Default::default()
        };
        state.queues = vec![bad_q];
        state.activated = true;
        state
            .build_queues_checked(&mem, 0, state.queues.len(), max_size)
            .unwrap_err();

        // NOTE(review): this default queue is not marked `ready`; restoration
        // of an activated device presumably fails for that reason — confirm
        // against `Queue::initialize`.
        let bad_q = QueueState::default();
        state.queues = vec![bad_q];
        state.activated = true;
        state
            .build_queues_checked(&mem, 0, state.queues.len(), max_size)
            .unwrap_err();
    }

    // Round-trips a queue through Snapshot save/restore and checks equality.
    #[test]
    fn test_queue_persistence() {
        let mem = default_mem();

        let mut queue = Queue::new(128);
        queue.ready = true;
        queue.size = queue.max_size;
        queue.initialize(&mem).unwrap();

        let mut bytes = vec![0; 4096];

        Snapshot::new(queue.save())
            .save(&mut bytes.as_mut_slice())
            .unwrap();

        // `is_activated: true` makes restore call `initialize`, recomputing
        // the raw ring pointers so the equality check below can hold.
        let ca = QueueConstructorArgs {
            mem,
            is_activated: true,
        };
        let restored_queue = Queue::restore(
            ca,
            &Snapshot::load_without_crc_check(bytes.as_slice())
                .unwrap()
                .data,
        )
        .unwrap();

        assert_eq!(restored_queue, queue);
    }

    // Serializes a dummy device's VirtioDeviceState and checks the
    // deserialized copy is identical.
    #[test]
    fn test_virtio_device_state_serde() {
        let dummy = DummyDevice::new();
        let mut mem = vec![0; 4096];

        let state = VirtioDeviceState::from_device(&dummy);
        Snapshot::new(&state).save(&mut mem.as_mut_slice()).unwrap();

        let restored_state: VirtioDeviceState = Snapshot::load_without_crc_check(mem.as_slice())
            .unwrap()
            .data;
        assert_eq!(restored_state, state);
    }

    // Test-only equality for transports: compares the persisted register
    // fields, the interrupt status, and the wrapped device's type (not the
    // full device state).
    impl PartialEq for MmioTransport {
        fn eq(&self, other: &MmioTransport) -> bool {
            let self_dev_type = self.device().lock().unwrap().device_type();
            self.acked_features_select == other.acked_features_select &&
            self.features_select == other.features_select &&
            self.queue_select == other.queue_select &&
            self.device_status == other.device_status &&
            self.config_generation == other.config_generation &&
            self.interrupt.irq_status.load(Ordering::SeqCst) == other.interrupt.irq_status.load(Ordering::SeqCst) &&
            self_dev_type == other.device().lock().unwrap().device_type()
        }
    }

    // Saves `mmio_transport` to a snapshot buffer, restores it with the given
    // constructor args, and checks the round trip preserved its state.
    fn generic_mmiotransport_persistence_test(
        mmio_transport: MmioTransport,
        interrupt: Arc<IrqTrigger>,
        mem: GuestMemoryMmap,
        device: Arc<Mutex<dyn VirtioDevice>>,
    ) {
        let mut buf = vec![0; 4096];

        Snapshot::new(mmio_transport.save())
            .save(&mut buf.as_mut_slice())
            .unwrap();

        let restore_args = MmioTransportConstructorArgs {
            mem,
            interrupt,
            device,
            is_vhost_user: false,
        };
        let restored_mmio_transport = MmioTransport::restore(
            restore_args,
            &Snapshot::load_without_crc_check(buf.as_slice())
                .unwrap()
                .data,
        )
        .unwrap();

        assert_eq!(restored_mmio_transport, mmio_transport);
    }

    // Builds a block device backed by a small temp file, wrapped in an MMIO
    // transport, for the persistence tests below.
    fn create_default_block() -> (
        MmioTransport,
        Arc<IrqTrigger>,
        GuestMemoryMmap,
        Arc<Mutex<VirtioBlock>>,
    ) {
        let mem = default_mem();
        let interrupt = Arc::new(IrqTrigger::new());

        // Give the backing file a non-zero length so the device has capacity.
        let f = TempFile::new().unwrap();
        f.as_file().set_len(0x1000).unwrap();
        let block = default_block_with_path(
            f.as_path().to_str().unwrap().to_string(),
            FileEngineType::default(),
        );
        let block = Arc::new(Mutex::new(block));
        let mmio_transport =
            MmioTransport::new(mem.clone(), interrupt.clone(), block.clone(), false);

        (mmio_transport, interrupt, mem, block)
    }

    // Builds a default net device wrapped in an MMIO transport.
    fn create_default_net() -> (
        MmioTransport,
        Arc<IrqTrigger>,
        GuestMemoryMmap,
        Arc<Mutex<Net>>,
    ) {
        let mem = default_mem();
        let interrupt = Arc::new(IrqTrigger::new());
        let net = Arc::new(Mutex::new(default_net()));
        let mmio_transport = MmioTransport::new(mem.clone(), interrupt.clone(), net.clone(), false);

        (mmio_transport, interrupt, mem, net)
    }

    // Builds a vsock device with a Unix-socket backend wrapped in an MMIO
    // transport.
    fn default_vsock() -> (
        MmioTransport,
        Arc<IrqTrigger>,
        GuestMemoryMmap,
        Arc<Mutex<Vsock<VsockUnixBackend>>>,
    ) {
        let mem = default_mem();
        let interrupt = Arc::new(IrqTrigger::new());

        let guest_cid = 52;
        // Only a unique path is needed; the temp file is removed first —
        // presumably the backend binds the Unix socket itself.
        let mut temp_uds_path = TempFile::new().unwrap();
        temp_uds_path.remove().unwrap();
        let uds_path = String::from(temp_uds_path.as_path().to_str().unwrap());
        let backend = VsockUnixBackend::new(guest_cid, uds_path).unwrap();
        let vsock = Vsock::new(guest_cid, backend).unwrap();
        let vsock = Arc::new(Mutex::new(vsock));
        let mmio_transport =
            MmioTransport::new(mem.clone(), interrupt.clone(), vsock.clone(), false);

        (mmio_transport, interrupt, mem, vsock)
    }

    #[test]
    fn test_block_over_mmiotransport_persistence() {
        let (mmio_transport, interrupt, mem, block) = create_default_block();
        generic_mmiotransport_persistence_test(mmio_transport, interrupt, mem, block);
    }

    #[test]
    fn test_net_over_mmiotransport_persistence() {
        let (mmio_transport, interrupt, mem, net) = create_default_net();
        generic_mmiotransport_persistence_test(mmio_transport, interrupt, mem, net);
    }

    #[test]
    fn test_vsock_over_mmiotransport_persistence() {
        let (mmio_transport, interrupt, mem, vsock) = default_vsock();
        generic_mmiotransport_persistence_test(mmio_transport, interrupt, mem, vsock);
    }
}