1use std::os::fd::AsRawFd;
8use std::os::unix::net::UnixStream;
9use std::sync::Arc;
10
11use vhost::vhost_user::message::*;
12use vhost::vhost_user::{Frontend, VhostUserFrontend};
13use vhost::{Error as VhostError, VhostBackend, VhostUserMemoryRegionInfo, VringConfigData};
14use vm_memory::{Address, GuestMemory, GuestMemoryError, GuestMemoryRegion};
15use vmm_sys_util::eventfd::EventFd;
16
17use crate::devices::virtio::queue::Queue;
18use crate::devices::virtio::transport::{VirtioInterrupt, VirtioInterruptType};
19use crate::vstate::memory::GuestMemoryMmap;
20
21#[derive(Debug, thiserror::Error, displaydoc::Display)]
23pub enum VhostUserError {
24 AvailAddress(GuestMemoryError),
26 Connect(#[from] std::io::Error),
28 DescriptorTableAddress(GuestMemoryError),
30 VhostUserGetFeatures(VhostError),
32 VhostUserGetProtocolFeatures(VhostError),
34 VhostUserSetOwner(VhostError),
36 VhostUserSetFeatures(VhostError),
38 VhostUserSetProtocolFeatures(VhostError),
40 VhostUserSetMemTable(VhostError),
42 VhostUserSetVringNum(VhostError),
44 VhostUserSetVringAddr(VhostError),
46 VhostUserSetVringBase(VhostError),
48 VhostUserSetVringCall(VhostError),
50 VhostUserSetVringKick(VhostError),
52 VhostUserSetVringEnable(VhostError),
54 VhostUserNoMemoryRegion,
56 UsedAddress(GuestMemoryError),
58}
59
60pub trait VhostUserHandleBackend: Sized {
65 fn from_stream(_sock: UnixStream, _max_queue_num: u64) -> Self {
67 unimplemented!()
68 }
69
70 fn set_hdr_flags(&self, _flags: VhostUserHeaderFlag) {
71 unimplemented!()
72 }
73
74 fn get_features(&self) -> Result<u64, vhost::Error> {
76 unimplemented!()
77 }
78
79 fn set_features(&self, _features: u64) -> Result<(), vhost::Error> {
81 unimplemented!()
82 }
83
84 fn set_owner(&self) -> Result<(), vhost::Error> {
86 unimplemented!()
87 }
88
89 fn set_mem_table(&self, _regions: &[VhostUserMemoryRegionInfo]) -> Result<(), vhost::Error> {
92 unimplemented!()
93 }
94
95 fn set_vring_num(&self, _queue_index: usize, _num: u16) -> Result<(), vhost::Error> {
97 unimplemented!()
98 }
99
100 fn set_vring_addr(
102 &self,
103 _queue_index: usize,
104 _config_data: &VringConfigData,
105 ) -> Result<(), vhost::Error> {
106 unimplemented!()
107 }
108
109 fn set_vring_base(&self, _queue_index: usize, _base: u16) -> Result<(), vhost::Error> {
111 unimplemented!()
112 }
113
114 fn set_vring_call(&self, _queue_index: usize, _fd: &EventFd) -> Result<(), vhost::Error> {
119 unimplemented!()
120 }
121
122 fn set_vring_kick(&self, _queue_index: usize, _fd: &EventFd) -> Result<(), vhost::Error> {
127 unimplemented!()
128 }
129
130 fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures, vhost::Error> {
131 unimplemented!()
132 }
133
134 fn set_protocol_features(
135 &mut self,
136 _features: VhostUserProtocolFeatures,
137 ) -> Result<(), vhost::Error> {
138 unimplemented!()
139 }
140
141 fn set_vring_enable(&mut self, _queue_index: usize, _enable: bool) -> Result<(), vhost::Error> {
142 unimplemented!()
143 }
144
145 fn get_config(
146 &mut self,
147 _offset: u32,
148 _size: u32,
149 _flags: VhostUserConfigFlags,
150 _buf: &[u8],
151 ) -> Result<(VhostUserConfig, VhostUserConfigPayload), vhost::Error> {
152 unimplemented!()
153 }
154
155 fn set_config(
156 &mut self,
157 _offset: u32,
158 _flags: VhostUserConfigFlags,
159 _buf: &[u8],
160 ) -> Result<(), vhost::Error> {
161 unimplemented!()
162 }
163}
164
165impl VhostUserHandleBackend for Frontend {
166 fn from_stream(sock: UnixStream, max_queue_num: u64) -> Self {
167 Frontend::from_stream(sock, max_queue_num)
168 }
169
170 fn set_hdr_flags(&self, flags: VhostUserHeaderFlag) {
171 self.set_hdr_flags(flags)
172 }
173
174 fn get_features(&self) -> Result<u64, vhost::Error> {
176 <Frontend as VhostBackend>::get_features(self)
177 }
178
179 fn set_features(&self, features: u64) -> Result<(), vhost::Error> {
181 <Frontend as VhostBackend>::set_features(self, features)
182 }
183
184 fn set_owner(&self) -> Result<(), vhost::Error> {
186 <Frontend as VhostBackend>::set_owner(self)
187 }
188
189 fn set_mem_table(&self, regions: &[VhostUserMemoryRegionInfo]) -> Result<(), vhost::Error> {
192 <Frontend as VhostBackend>::set_mem_table(self, regions)
193 }
194
195 fn set_vring_num(&self, queue_index: usize, num: u16) -> Result<(), vhost::Error> {
197 <Frontend as VhostBackend>::set_vring_num(self, queue_index, num)
198 }
199
200 fn set_vring_addr(
202 &self,
203 queue_index: usize,
204 config_data: &VringConfigData,
205 ) -> Result<(), vhost::Error> {
206 <Frontend as VhostBackend>::set_vring_addr(self, queue_index, config_data)
207 }
208
209 fn set_vring_base(&self, queue_index: usize, base: u16) -> Result<(), vhost::Error> {
211 <Frontend as VhostBackend>::set_vring_base(self, queue_index, base)
212 }
213
214 fn set_vring_call(&self, queue_index: usize, fd: &EventFd) -> Result<(), vhost::Error> {
219 <Frontend as VhostBackend>::set_vring_call(self, queue_index, fd)
220 }
221
222 fn set_vring_kick(&self, queue_index: usize, fd: &EventFd) -> Result<(), vhost::Error> {
227 <Frontend as VhostBackend>::set_vring_kick(self, queue_index, fd)
228 }
229
230 fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures, vhost::Error> {
231 <Frontend as VhostUserFrontend>::get_protocol_features(self)
232 }
233
234 fn set_protocol_features(
235 &mut self,
236 features: VhostUserProtocolFeatures,
237 ) -> Result<(), vhost::Error> {
238 <Frontend as VhostUserFrontend>::set_protocol_features(self, features)
239 }
240
241 fn set_vring_enable(&mut self, queue_index: usize, enable: bool) -> Result<(), vhost::Error> {
242 <Frontend as VhostUserFrontend>::set_vring_enable(self, queue_index, enable)
243 }
244
245 fn get_config(
246 &mut self,
247 offset: u32,
248 size: u32,
249 flags: VhostUserConfigFlags,
250 buf: &[u8],
251 ) -> Result<(VhostUserConfig, VhostUserConfigPayload), vhost::Error> {
252 <Frontend as VhostUserFrontend>::get_config(self, offset, size, flags, buf)
253 }
254
255 fn set_config(
256 &mut self,
257 offset: u32,
258 flags: VhostUserConfigFlags,
259 buf: &[u8],
260 ) -> Result<(), vhost::Error> {
261 <Frontend as VhostUserFrontend>::set_config(self, offset, flags, buf)
262 }
263}
264
/// Vhost-user handle backed by the `vhost` crate's real [`Frontend`].
pub type VhostUserHandle = VhostUserHandleImpl<Frontend>;

/// Handle to a vhost-user backend.
///
/// Generic over [`VhostUserHandleBackend`] so that tests can substitute a mock for the
/// real [`Frontend`].
#[derive(Clone)]
pub struct VhostUserHandleImpl<T: VhostUserHandleBackend> {
    /// Backend through which vhost-user requests are issued.
    pub vu: T,
    /// Unix domain socket path the handle was created for (set by `new`).
    pub socket_path: String,
}
273
274impl<T: VhostUserHandleBackend> std::fmt::Debug for VhostUserHandleImpl<T> {
275 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
276 f.debug_struct("VhostUserHandle")
277 .field("socket_path", &self.socket_path)
278 .finish()
279 }
280}
281
282impl<T: VhostUserHandleBackend> VhostUserHandleImpl<T> {
283 pub fn new(socket_path: &str, num_queues: u64) -> Result<Self, VhostUserError> {
286 let stream = UnixStream::connect(socket_path).map_err(VhostUserError::Connect)?;
287
288 let vu = T::from_stream(stream, num_queues);
289 vu.set_owner().map_err(VhostUserError::VhostUserSetOwner)?;
290
291 Ok(Self {
292 vu,
293 socket_path: socket_path.to_string(),
294 })
295 }
296
297 pub fn set_features(&self, features: u64) -> Result<(), VhostUserError> {
299 self.vu
300 .set_features(features)
301 .map_err(VhostUserError::VhostUserSetFeatures)
302 }
303
304 pub fn set_protocol_features(
306 &mut self,
307 acked_features: u64,
308 acked_protocol_features: u64,
309 ) -> Result<(), VhostUserError> {
310 if acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() != 0
311 && let Some(acked_protocol_features) =
312 VhostUserProtocolFeatures::from_bits(acked_protocol_features)
313 {
314 self.vu
315 .set_protocol_features(acked_protocol_features)
316 .map_err(VhostUserError::VhostUserSetProtocolFeatures)?;
317
318 if acked_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK) {
319 self.vu.set_hdr_flags(VhostUserHeaderFlag::NEED_REPLY);
320 }
321 }
322
323 Ok(())
324 }
325
326 pub fn negotiate_features(
328 &mut self,
329 avail_features: u64,
330 avail_protocol_features: VhostUserProtocolFeatures,
331 ) -> Result<(u64, u64), VhostUserError> {
332 let backend_features = self
335 .vu
336 .get_features()
337 .map_err(VhostUserError::VhostUserGetFeatures)?;
338 let acked_features = avail_features & backend_features;
339
340 let acked_protocol_features =
341 if acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() != 0 {
343 let backend_protocol_features = self
344 .vu
345 .get_protocol_features()
346 .map_err(VhostUserError::VhostUserGetProtocolFeatures)?;
347
348 let acked_protocol_features = avail_protocol_features & backend_protocol_features;
349
350 self.vu
351 .set_protocol_features(acked_protocol_features)
352 .map_err(VhostUserError::VhostUserSetProtocolFeatures)?;
353
354 acked_protocol_features
355 } else {
356 VhostUserProtocolFeatures::empty()
357 };
358
359 if acked_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK) {
360 self.vu.set_hdr_flags(VhostUserHeaderFlag::NEED_REPLY);
361 }
362
363 Ok((acked_features, acked_protocol_features.bits()))
364 }
365
366 fn update_mem_table(&self, mem: &GuestMemoryMmap) -> Result<(), VhostUserError> {
368 let mut regions: Vec<VhostUserMemoryRegionInfo> = Vec::new();
369
370 for region in mem.iter() {
371 let (mmap_handle, mmap_offset) = match region.file_offset() {
372 Some(_file_offset) => (_file_offset.file().as_raw_fd(), _file_offset.start()),
373 None => {
374 return Err(VhostUserError::VhostUserNoMemoryRegion);
375 }
376 };
377
378 let vhost_user_net_reg = VhostUserMemoryRegionInfo {
379 guest_phys_addr: region.start_addr().raw_value(),
380 memory_size: region.len(),
381 userspace_addr: region.inner.as_ptr() as u64,
382 mmap_offset,
383 mmap_handle,
384 };
385 regions.push(vhost_user_net_reg);
386 }
387
388 self.vu
389 .set_mem_table(regions.as_slice())
390 .map_err(VhostUserError::VhostUserSetMemTable)?;
391
392 Ok(())
393 }
394
395 pub fn setup_backend(
398 &mut self,
399 mem: &GuestMemoryMmap,
400 queues: &[(usize, &Queue, &EventFd)],
401 interrupt: Arc<dyn VirtioInterrupt>,
402 ) -> Result<(), VhostUserError> {
403 self.update_mem_table(mem)?;
405
406 for (queue_index, queue, _) in queues.iter() {
410 self.vu
411 .set_vring_num(*queue_index, queue.size)
412 .map_err(VhostUserError::VhostUserSetVringNum)?;
413 }
414
415 for (queue_index, queue, queue_evt) in queues.iter() {
416 let config_data = VringConfigData {
417 queue_max_size: queue.max_size,
418 queue_size: queue.size,
419 flags: 0u32,
420 desc_table_addr: mem
421 .get_host_address(queue.desc_table_address)
422 .map_err(VhostUserError::DescriptorTableAddress)?
423 as u64,
424 used_ring_addr: mem
425 .get_host_address(queue.used_ring_address)
426 .map_err(VhostUserError::UsedAddress)? as u64,
427 avail_ring_addr: mem
428 .get_host_address(queue.avail_ring_address)
429 .map_err(VhostUserError::AvailAddress)? as u64,
430 log_addr: None,
431 };
432
433 self.vu
434 .set_vring_addr(*queue_index, &config_data)
435 .map_err(VhostUserError::VhostUserSetVringAddr)?;
436 self.vu
437 .set_vring_base(*queue_index, queue.avail_ring_idx_get())
438 .map_err(VhostUserError::VhostUserSetVringBase)?;
439
440 self.vu
443 .set_vring_call(
444 *queue_index,
445 interrupt
446 .notifier(VirtioInterruptType::Queue(
447 (*queue_index).try_into().unwrap_or_else(|_| {
448 panic!("vhost-user: invalid queue index: {}", *queue_index)
449 }),
450 ))
451 .as_ref()
452 .unwrap(),
453 )
454 .map_err(VhostUserError::VhostUserSetVringCall)?;
455
456 self.vu
457 .set_vring_kick(*queue_index, queue_evt)
458 .map_err(VhostUserError::VhostUserSetVringKick)?;
459
460 self.vu
461 .set_vring_enable(*queue_index, true)
462 .map_err(VhostUserError::VhostUserSetVringEnable)?;
463 }
464
465 Ok(())
466 }
467}
468
#[cfg(test)]
pub(crate) mod tests {
    // The mocks below record calls made through `&self` methods by mutating `UnsafeCell`
    // contents. The tests are single-threaded and only read a cell after the code under
    // test has returned, so the individual unsafe blocks are not annotated separately.
    #![allow(clippy::undocumented_unsafe_blocks)]

    use std::fs::File;

    use vmm_sys_util::tempfile::TempFile;

    use super::*;
    use crate::devices::virtio::test_utils::default_interrupt;
    use crate::test_utils::create_tmp_socket;
    use crate::vstate::memory;
    use crate::vstate::memory::{GuestAddress, GuestRegionMmapExt};

    /// Builds file-backed guest memory with the given `(start, size)` regions.
    pub(crate) fn create_mem(file: File, regions: &[(GuestAddress, usize)]) -> GuestMemoryMmap {
        GuestMemoryMmap::from_regions(
            memory::create(
                regions.iter().copied(),
                libc::MAP_PRIVATE,
                Some(file),
                false,
            )
            .unwrap()
            .into_iter()
            .map(|region| GuestRegionMmapExt::dram_from_mmap_region(region, 0))
            .collect(),
        )
        .unwrap()
    }

    #[test]
    fn test_new() {
        struct MockFrontend {
            sock: UnixStream,
            max_queue_num: u64,
            is_owner: std::cell::UnsafeCell<bool>,
        }

        impl VhostUserHandleBackend for MockFrontend {
            fn from_stream(sock: UnixStream, max_queue_num: u64) -> Self {
                Self {
                    sock,
                    max_queue_num,
                    is_owner: std::cell::UnsafeCell::new(false),
                }
            }

            fn set_owner(&self) -> Result<(), vhost::Error> {
                unsafe { *self.is_owner.get() = true };
                Ok(())
            }
        }

        let max_queue_num = 69;

        let (_tmp_dir, tmp_socket_path) = create_tmp_socket();

        let vuh =
            VhostUserHandleImpl::<MockFrontend>::new(&tmp_socket_path, max_queue_num).unwrap();
        assert_eq!(
            vuh.vu
                .sock
                .peer_addr()
                .unwrap()
                .as_pathname()
                .unwrap()
                .to_str()
                .unwrap(),
            &tmp_socket_path,
        );
        assert_eq!(vuh.vu.max_queue_num, max_queue_num);
        assert!(unsafe { *vuh.vu.is_owner.get() });
    }

    #[test]
    fn test_set_features() {
        struct MockFrontend {
            features: std::cell::UnsafeCell<u64>,
        }

        impl VhostUserHandleBackend for MockFrontend {
            fn set_features(&self, features: u64) -> Result<(), vhost::Error> {
                unsafe { *self.features.get() = features };
                Ok(())
            }
        }

        let vuh = VhostUserHandleImpl {
            vu: MockFrontend { features: 0.into() },
            socket_path: "".to_string(),
        };
        vuh.set_features(0x69).unwrap();
        assert_eq!(unsafe { *vuh.vu.features.get() }, 0x69);
    }

    #[test]
    fn test_set_protocol_features() {
        struct MockFrontend {
            protocol_features: VhostUserProtocolFeatures,
            hdr_flags: std::cell::UnsafeCell<VhostUserHeaderFlag>,
        }

        impl VhostUserHandleBackend for MockFrontend {
            fn set_hdr_flags(&self, flags: VhostUserHeaderFlag) {
                unsafe { *self.hdr_flags.get() = flags };
            }

            fn set_protocol_features(
                &mut self,
                features: VhostUserProtocolFeatures,
            ) -> Result<(), vhost::Error> {
                self.protocol_features = features;
                Ok(())
            }
        }

        let mut vuh = VhostUserHandleImpl {
            vu: MockFrontend {
                protocol_features: VhostUserProtocolFeatures::empty(),
                hdr_flags: std::cell::UnsafeCell::new(VhostUserHeaderFlag::empty()),
            },
            socket_path: "".to_string(),
        };

        // No PROTOCOL_FEATURES and no protocol features: nothing is sent.
        let acked_features = 0;
        let acked_protocol_features = VhostUserProtocolFeatures::empty();
        vuh.set_protocol_features(acked_features, acked_protocol_features.bits())
            .unwrap();
        assert_eq!(vuh.vu.protocol_features, VhostUserProtocolFeatures::empty());
        assert_eq!(
            unsafe { &*vuh.vu.hdr_flags.get() }.bits(),
            VhostUserHeaderFlag::empty().bits()
        );

        // No PROTOCOL_FEATURES: protocol features are ignored.
        let acked_features = 0;
        let acked_protocol_features = VhostUserProtocolFeatures::all();
        vuh.set_protocol_features(acked_features, acked_protocol_features.bits())
            .unwrap();
        assert_eq!(vuh.vu.protocol_features, VhostUserProtocolFeatures::empty());
        assert_eq!(
            unsafe { &*vuh.vu.hdr_flags.get() }.bits(),
            VhostUserHeaderFlag::empty().bits()
        );

        // PROTOCOL_FEATURES without REPLY_ACK: features sent, header flags untouched.
        let acked_features = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
        let mut acked_protocol_features = VhostUserProtocolFeatures::all();
        acked_protocol_features.set(VhostUserProtocolFeatures::REPLY_ACK, false);
        vuh.set_protocol_features(acked_features, acked_protocol_features.bits())
            .unwrap();
        assert_eq!(vuh.vu.protocol_features, acked_protocol_features);
        assert_eq!(
            unsafe { &*vuh.vu.hdr_flags.get() }.bits(),
            VhostUserHeaderFlag::empty().bits()
        );

        // PROTOCOL_FEATURES with REPLY_ACK: features sent and NEED_REPLY set.
        let acked_features = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
        let acked_protocol_features = VhostUserProtocolFeatures::all();
        vuh.set_protocol_features(acked_features, acked_protocol_features.bits())
            .unwrap();
        assert_eq!(vuh.vu.protocol_features, acked_protocol_features);
        assert_eq!(
            unsafe { &*vuh.vu.hdr_flags.get() }.bits(),
            VhostUserHeaderFlag::NEED_REPLY.bits()
        );
    }

    #[test]
    fn test_negotiate_features() {
        struct MockFrontend {
            features: u64,
            protocol_features: VhostUserProtocolFeatures,
            hdr_flags: std::cell::UnsafeCell<VhostUserHeaderFlag>,
        }

        impl VhostUserHandleBackend for MockFrontend {
            fn set_hdr_flags(&self, flags: VhostUserHeaderFlag) {
                unsafe { *self.hdr_flags.get() = flags };
            }

            fn get_features(&self) -> Result<u64, vhost::Error> {
                Ok(self.features)
            }

            fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures, vhost::Error> {
                Ok(self.protocol_features)
            }

            fn set_protocol_features(
                &mut self,
                features: VhostUserProtocolFeatures,
            ) -> Result<(), vhost::Error> {
                self.protocol_features = features;
                Ok(())
            }
        }

        let mut vuh = VhostUserHandleImpl {
            vu: MockFrontend {
                features: 0,
                protocol_features: VhostUserProtocolFeatures::empty(),
                hdr_flags: std::cell::UnsafeCell::new(VhostUserHeaderFlag::empty()),
            },
            socket_path: "".to_string(),
        };

        // Nothing offered on either side.
        let avail_features = 0;
        let avail_protocol_features = VhostUserProtocolFeatures::empty();
        let (acked_features, acked_protocol_features) = vuh
            .negotiate_features(avail_features, avail_protocol_features)
            .unwrap();
        assert_eq!(acked_features, avail_features);
        assert_eq!(acked_protocol_features, avail_protocol_features.bits());
        assert_eq!(vuh.vu.protocol_features, VhostUserProtocolFeatures::empty());
        assert_eq!(
            unsafe { &*vuh.vu.hdr_flags.get() }.bits(),
            VhostUserHeaderFlag::empty().bits()
        );

        // Every virtio feature except PROTOCOL_FEATURES.
        let mut avail_features = VhostUserVirtioFeatures::all();
        avail_features.set(VhostUserVirtioFeatures::PROTOCOL_FEATURES, false);

        vuh.vu.features = avail_features.bits();

        let avail_protocol_features = VhostUserProtocolFeatures::empty();
        let (acked_features, acked_protocol_features) = vuh
            .negotiate_features(avail_features.bits(), avail_protocol_features)
            .unwrap();
        assert_eq!(acked_features, avail_features.bits());
        assert_eq!(acked_protocol_features, avail_protocol_features.bits());
        assert_eq!(vuh.vu.protocol_features, VhostUserProtocolFeatures::empty());
        assert_eq!(
            unsafe { &*vuh.vu.hdr_flags.get() }.bits(),
            VhostUserHeaderFlag::empty().bits()
        );

        // PROTOCOL_FEATURES negotiated; protocol features are intersected.
        let avail_features = VhostUserVirtioFeatures::all();
        vuh.vu.features = avail_features.bits();

        let mut avail_protocol_features = VhostUserProtocolFeatures::empty();
        avail_protocol_features.set(VhostUserProtocolFeatures::CONFIG, true);

        let mut backend_protocol_features = VhostUserProtocolFeatures::empty();
        backend_protocol_features.set(VhostUserProtocolFeatures::CONFIG, true);
        backend_protocol_features.set(VhostUserProtocolFeatures::PAGEFAULT, true);
        vuh.vu.protocol_features = backend_protocol_features;

        let (acked_features, acked_protocol_features) = vuh
            .negotiate_features(avail_features.bits(), avail_protocol_features)
            .unwrap();
        assert_eq!(acked_features, avail_features.bits());
        assert_eq!(acked_protocol_features, avail_protocol_features.bits());
        assert_eq!(vuh.vu.protocol_features, avail_protocol_features);
        assert_eq!(
            unsafe { &*vuh.vu.hdr_flags.get() }.bits(),
            VhostUserHeaderFlag::empty().bits()
        );

        // REPLY_ACK negotiated: NEED_REPLY is set.
        let avail_features = VhostUserVirtioFeatures::all();
        vuh.vu.features = avail_features.bits();

        let mut avail_protocol_features = VhostUserProtocolFeatures::empty();
        avail_protocol_features.set(VhostUserProtocolFeatures::REPLY_ACK, true);

        vuh.vu.protocol_features = avail_protocol_features;

        let (acked_features, acked_protocol_features) = vuh
            .negotiate_features(avail_features.bits(), avail_protocol_features)
            .unwrap();
        assert_eq!(acked_features, avail_features.bits());
        assert_eq!(acked_protocol_features, avail_protocol_features.bits());
        assert_eq!(vuh.vu.protocol_features, avail_protocol_features);
        assert_eq!(
            unsafe { &*vuh.vu.hdr_flags.get() }.bits(),
            VhostUserHeaderFlag::NEED_REPLY.bits(),
        );
    }

    #[test]
    fn test_update_mem_table() {
        struct MockFrontend {
            regions: std::cell::UnsafeCell<Vec<VhostUserMemoryRegionInfo>>,
        }

        impl VhostUserHandleBackend for MockFrontend {
            fn set_mem_table(
                &self,
                regions: &[VhostUserMemoryRegionInfo],
            ) -> Result<(), vhost::Error> {
                unsafe { (*self.regions.get()).extend_from_slice(regions) }
                Ok(())
            }
        }

        let vuh = VhostUserHandleImpl {
            vu: MockFrontend {
                regions: std::cell::UnsafeCell::new(vec![]),
            },
            socket_path: "".to_string(),
        };

        let region_size = 0x10000;
        let file = TempFile::new().unwrap().into_file();
        let file_size = 2 * region_size;
        file.set_len(file_size as u64).unwrap();
        let regions = vec![
            (GuestAddress(0x0), region_size),
            (GuestAddress(0x10000), region_size),
        ];

        let guest_memory = create_mem(file, &regions);

        vuh.update_mem_table(&guest_memory).unwrap();

        let expected_regions = guest_memory
            .iter()
            .map(|region| VhostUserMemoryRegionInfo {
                guest_phys_addr: region.start_addr().raw_value(),
                memory_size: region.len(),
                userspace_addr: region.inner.as_ptr() as u64,
                mmap_offset: region.file_offset().unwrap().start(),
                mmap_handle: region.file_offset().unwrap().file().as_raw_fd(),
            })
            .collect::<Vec<_>>();

        let recorded = unsafe { &*vuh.vu.regions.get() };
        // `zip` stops at the shorter side, so check the counts first to avoid a vacuous pass.
        assert_eq!(recorded.len(), expected_regions.len());
        assert_eq!(recorded.len(), regions.len());

        for (region, expected) in recorded.iter().zip(expected_regions) {
            assert_eq!(region.guest_phys_addr, expected.guest_phys_addr);
            assert_eq!(region.memory_size, expected.memory_size);
            assert_eq!(region.userspace_addr, expected.userspace_addr);
            assert_eq!(region.mmap_offset, expected.mmap_offset);
            assert_eq!(region.mmap_handle, expected.mmap_handle);
        }
    }

    #[test]
    fn test_setup_backend() {
        #[derive(Default)]
        struct VringData {
            index: usize,
            size: u16,
            config: VringConfigData,
            base: u16,
            call: i32,
            kick: i32,
            enable: bool,
        }

        struct MockFrontend {
            vrings: std::cell::UnsafeCell<Vec<VringData>>,
        }

        impl VhostUserHandleBackend for MockFrontend {
            fn set_mem_table(
                &self,
                _regions: &[VhostUserMemoryRegionInfo],
            ) -> Result<(), vhost::Error> {
                Ok(())
            }

            fn set_vring_num(&self, queue_index: usize, num: u16) -> Result<(), vhost::Error> {
                unsafe {
                    (*self.vrings.get()).push(VringData {
                        index: queue_index,
                        size: num,
                        ..Default::default()
                    })
                };
                Ok(())
            }

            fn set_vring_addr(
                &self,
                queue_index: usize,
                config_data: &VringConfigData,
            ) -> Result<(), vhost::Error> {
                unsafe { (&mut (*self.vrings.get()))[queue_index].config = *config_data };
                Ok(())
            }

            fn set_vring_base(&self, queue_index: usize, base: u16) -> Result<(), vhost::Error> {
                unsafe { (&mut (*self.vrings.get()))[queue_index].base = base };
                Ok(())
            }

            fn set_vring_call(&self, queue_index: usize, fd: &EventFd) -> Result<(), vhost::Error> {
                unsafe { (&mut (*self.vrings.get()))[queue_index].call = fd.as_raw_fd() };
                Ok(())
            }

            fn set_vring_kick(&self, queue_index: usize, fd: &EventFd) -> Result<(), vhost::Error> {
                unsafe { (&mut (*self.vrings.get()))[queue_index].kick = fd.as_raw_fd() };
                Ok(())
            }

            fn set_vring_enable(
                &mut self,
                queue_index: usize,
                enable: bool,
            ) -> Result<(), vhost::Error> {
                unsafe { &mut *self.vrings.get() }
                    .get_mut(queue_index)
                    .unwrap()
                    .enable = enable;
                Ok(())
            }
        }

        let mut vuh = VhostUserHandleImpl {
            vu: MockFrontend {
                vrings: std::cell::UnsafeCell::new(vec![]),
            },
            socket_path: "".to_string(),
        };

        let region_size = 0x10000;
        let file = TempFile::new().unwrap().into_file();
        file.set_len(region_size as u64).unwrap();
        let regions = vec![(GuestAddress(0x0), region_size)];

        let guest_memory = create_mem(file, &regions);

        let mut queue = Queue::new(128);
        queue.ready = true;
        queue.size = queue.max_size;
        queue.initialize(&guest_memory).unwrap();

        let event_fd = EventFd::new(0).unwrap();

        let queues = [(0, &queue, &event_fd)];

        let interrupt = default_interrupt();
        vuh.setup_backend(&guest_memory, &queues, interrupt.clone())
            .unwrap();

        let expected_config = VringData {
            index: 0,
            size: 128,
            config: VringConfigData {
                queue_max_size: 128,
                queue_size: 128,
                flags: 0,
                desc_table_addr: guest_memory
                    .get_host_address(queue.desc_table_address)
                    .unwrap() as u64,
                used_ring_addr: guest_memory
                    .get_host_address(queue.used_ring_address)
                    .unwrap() as u64,
                avail_ring_addr: guest_memory
                    .get_host_address(queue.avail_ring_address)
                    .unwrap() as u64,
                log_addr: None,
            },
            base: queue.avail_ring_idx_get(),
            call: interrupt
                .notifier(VirtioInterruptType::Queue(0u16))
                .as_ref()
                .unwrap()
                .as_raw_fd(),
            kick: event_fd.as_raw_fd(),
            enable: true,
        };

        let result = unsafe { &*vuh.vu.vrings.get() };
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].index, expected_config.index);
        assert_eq!(result[0].size, expected_config.size);

        assert_eq!(
            result[0].config.queue_max_size,
            expected_config.config.queue_max_size
        );
        assert_eq!(
            result[0].config.queue_size,
            expected_config.config.queue_size
        );
        assert_eq!(result[0].config.flags, expected_config.config.flags);
        assert_eq!(
            result[0].config.desc_table_addr,
            expected_config.config.desc_table_addr
        );
        assert_eq!(
            result[0].config.used_ring_addr,
            expected_config.config.used_ring_addr
        );
        assert_eq!(
            result[0].config.avail_ring_addr,
            expected_config.config.avail_ring_addr
        );
        assert_eq!(result[0].config.log_addr, expected_config.config.log_addr);

        assert_eq!(result[0].base, expected_config.base);
        assert_eq!(result[0].call, expected_config.call);
        assert_eq!(result[0].kick, expected_config.kick);
        assert_eq!(result[0].enable, expected_config.enable);
    }
}