vmm/devices/virtio/
vhost_user.rs

1// Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4// Portions Copyright 2019 Intel Corporation. All Rights Reserved.
5// SPDX-License-Identifier: Apache-2.0
6
7use std::os::fd::AsRawFd;
8use std::os::unix::net::UnixStream;
9use std::sync::Arc;
10
11use vhost::vhost_user::message::*;
12use vhost::vhost_user::{Frontend, VhostUserFrontend};
13use vhost::{Error as VhostError, VhostBackend, VhostUserMemoryRegionInfo, VringConfigData};
14use vm_memory::{Address, GuestMemory, GuestMemoryError, GuestMemoryRegion};
15use vmm_sys_util::eventfd::EventFd;
16
17use crate::devices::virtio::queue::Queue;
18use crate::devices::virtio::transport::{VirtioInterrupt, VirtioInterruptType};
19use crate::vstate::memory::GuestMemoryMmap;
20
21/// vhost-user error.
22#[derive(Debug, thiserror::Error, displaydoc::Display)]
23pub enum VhostUserError {
24    /// Invalid available address
25    AvailAddress(GuestMemoryError),
26    /// Failed to connect to UDS Unix stream: {0}
27    Connect(#[from] std::io::Error),
28    /// Invalid descriptor table address
29    DescriptorTableAddress(GuestMemoryError),
30    /// Get features failed: {0}
31    VhostUserGetFeatures(VhostError),
32    /// Get protocol features failed: {0}
33    VhostUserGetProtocolFeatures(VhostError),
34    /// Set owner failed: {0}
35    VhostUserSetOwner(VhostError),
36    /// Set features failed: {0}
37    VhostUserSetFeatures(VhostError),
38    /// Set protocol features failed: {0}
39    VhostUserSetProtocolFeatures(VhostError),
40    /// Set mem table failed: {0}
41    VhostUserSetMemTable(VhostError),
42    /// Set vring num failed: {0}
43    VhostUserSetVringNum(VhostError),
44    /// Set vring addr failed: {0}
45    VhostUserSetVringAddr(VhostError),
46    /// Set vring base failed: {0}
47    VhostUserSetVringBase(VhostError),
48    /// Set vring call failed: {0}
49    VhostUserSetVringCall(VhostError),
50    /// Set vring kick failed: {0}
51    VhostUserSetVringKick(VhostError),
52    /// Set vring enable failed: {0}
53    VhostUserSetVringEnable(VhostError),
54    /// Failed to read vhost eventfd: No memory region found
55    VhostUserNoMemoryRegion,
56    /// Invalid used address
57    UsedAddress(GuestMemoryError),
58}
59
60// Trait with all methods we use from `Frontend` from vhost crate.
61// It allows us to create a mock implementation of the `Frontend`
62// to verify calls to the backend.
63// All methods have default impl in order to simplify mock impls.
64pub trait VhostUserHandleBackend: Sized {
65    /// Constructor of `Frontend`
66    fn from_stream(_sock: UnixStream, _max_queue_num: u64) -> Self {
67        unimplemented!()
68    }
69
70    fn set_hdr_flags(&self, _flags: VhostUserHeaderFlag) {
71        unimplemented!()
72    }
73
74    /// Get from the underlying vhost implementation the feature bitmask.
75    fn get_features(&self) -> Result<u64, vhost::Error> {
76        unimplemented!()
77    }
78
79    /// Enable features in the underlying vhost implementation using a bitmask.
80    fn set_features(&self, _features: u64) -> Result<(), vhost::Error> {
81        unimplemented!()
82    }
83
84    /// Set the current Frontend as an owner of the session.
85    fn set_owner(&self) -> Result<(), vhost::Error> {
86        unimplemented!()
87    }
88
89    /// Set the memory map regions on the slave so it can translate the vring
90    /// addresses. In the ancillary data there is an array of file descriptors
91    fn set_mem_table(&self, _regions: &[VhostUserMemoryRegionInfo]) -> Result<(), vhost::Error> {
92        unimplemented!()
93    }
94
95    /// Set the size of the queue.
96    fn set_vring_num(&self, _queue_index: usize, _num: u16) -> Result<(), vhost::Error> {
97        unimplemented!()
98    }
99
100    /// Sets the addresses of the different aspects of the vring.
101    fn set_vring_addr(
102        &self,
103        _queue_index: usize,
104        _config_data: &VringConfigData,
105    ) -> Result<(), vhost::Error> {
106        unimplemented!()
107    }
108
109    /// Sets the base offset in the available vring.
110    fn set_vring_base(&self, _queue_index: usize, _base: u16) -> Result<(), vhost::Error> {
111        unimplemented!()
112    }
113
114    /// Set the event file descriptor to signal when buffers are used.
115    /// Bits (0-7) of the payload contain the vring index. Bit 8 is the invalid FD flag. This flag
116    /// is set when there is no file descriptor in the ancillary data. This signals that polling
117    /// will be used instead of waiting for the call.
118    fn set_vring_call(&self, _queue_index: usize, _fd: &EventFd) -> Result<(), vhost::Error> {
119        unimplemented!()
120    }
121
122    /// Set the event file descriptor for adding buffers to the vring.
123    /// Bits (0-7) of the payload contain the vring index. Bit 8 is the invalid FD flag. This flag
124    /// is set when there is no file descriptor in the ancillary data. This signals that polling
125    /// should be used instead of waiting for a kick.
126    fn set_vring_kick(&self, _queue_index: usize, _fd: &EventFd) -> Result<(), vhost::Error> {
127        unimplemented!()
128    }
129
130    fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures, vhost::Error> {
131        unimplemented!()
132    }
133
134    fn set_protocol_features(
135        &mut self,
136        _features: VhostUserProtocolFeatures,
137    ) -> Result<(), vhost::Error> {
138        unimplemented!()
139    }
140
141    fn set_vring_enable(&mut self, _queue_index: usize, _enable: bool) -> Result<(), vhost::Error> {
142        unimplemented!()
143    }
144
145    fn get_config(
146        &mut self,
147        _offset: u32,
148        _size: u32,
149        _flags: VhostUserConfigFlags,
150        _buf: &[u8],
151    ) -> Result<(VhostUserConfig, VhostUserConfigPayload), vhost::Error> {
152        unimplemented!()
153    }
154
155    fn set_config(
156        &mut self,
157        _offset: u32,
158        _flags: VhostUserConfigFlags,
159        _buf: &[u8],
160    ) -> Result<(), vhost::Error> {
161        unimplemented!()
162    }
163}
164
165impl VhostUserHandleBackend for Frontend {
166    fn from_stream(sock: UnixStream, max_queue_num: u64) -> Self {
167        Frontend::from_stream(sock, max_queue_num)
168    }
169
170    fn set_hdr_flags(&self, flags: VhostUserHeaderFlag) {
171        self.set_hdr_flags(flags)
172    }
173
174    /// Get from the underlying vhost implementation the feature bitmask.
175    fn get_features(&self) -> Result<u64, vhost::Error> {
176        <Frontend as VhostBackend>::get_features(self)
177    }
178
179    /// Enable features in the underlying vhost implementation using a bitmask.
180    fn set_features(&self, features: u64) -> Result<(), vhost::Error> {
181        <Frontend as VhostBackend>::set_features(self, features)
182    }
183
184    /// Set the current Frontend as an owner of the session.
185    fn set_owner(&self) -> Result<(), vhost::Error> {
186        <Frontend as VhostBackend>::set_owner(self)
187    }
188
189    /// Set the memory map regions on the slave so it can translate the vring
190    /// addresses. In the ancillary data there is an array of file descriptors
191    fn set_mem_table(&self, regions: &[VhostUserMemoryRegionInfo]) -> Result<(), vhost::Error> {
192        <Frontend as VhostBackend>::set_mem_table(self, regions)
193    }
194
195    /// Set the size of the queue.
196    fn set_vring_num(&self, queue_index: usize, num: u16) -> Result<(), vhost::Error> {
197        <Frontend as VhostBackend>::set_vring_num(self, queue_index, num)
198    }
199
200    /// Sets the addresses of the different aspects of the vring.
201    fn set_vring_addr(
202        &self,
203        queue_index: usize,
204        config_data: &VringConfigData,
205    ) -> Result<(), vhost::Error> {
206        <Frontend as VhostBackend>::set_vring_addr(self, queue_index, config_data)
207    }
208
209    /// Sets the base offset in the available vring.
210    fn set_vring_base(&self, queue_index: usize, base: u16) -> Result<(), vhost::Error> {
211        <Frontend as VhostBackend>::set_vring_base(self, queue_index, base)
212    }
213
214    /// Set the event file descriptor to signal when buffers are used.
215    /// Bits (0-7) of the payload contain the vring index. Bit 8 is the invalid FD flag. This flag
216    /// is set when there is no file descriptor in the ancillary data. This signals that polling
217    /// will be used instead of waiting for the call.
218    fn set_vring_call(&self, queue_index: usize, fd: &EventFd) -> Result<(), vhost::Error> {
219        <Frontend as VhostBackend>::set_vring_call(self, queue_index, fd)
220    }
221
222    /// Set the event file descriptor for adding buffers to the vring.
223    /// Bits (0-7) of the payload contain the vring index. Bit 8 is the invalid FD flag. This flag
224    /// is set when there is no file descriptor in the ancillary data. This signals that polling
225    /// should be used instead of waiting for a kick.
226    fn set_vring_kick(&self, queue_index: usize, fd: &EventFd) -> Result<(), vhost::Error> {
227        <Frontend as VhostBackend>::set_vring_kick(self, queue_index, fd)
228    }
229
230    fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures, vhost::Error> {
231        <Frontend as VhostUserFrontend>::get_protocol_features(self)
232    }
233
234    fn set_protocol_features(
235        &mut self,
236        features: VhostUserProtocolFeatures,
237    ) -> Result<(), vhost::Error> {
238        <Frontend as VhostUserFrontend>::set_protocol_features(self, features)
239    }
240
241    fn set_vring_enable(&mut self, queue_index: usize, enable: bool) -> Result<(), vhost::Error> {
242        <Frontend as VhostUserFrontend>::set_vring_enable(self, queue_index, enable)
243    }
244
245    fn get_config(
246        &mut self,
247        offset: u32,
248        size: u32,
249        flags: VhostUserConfigFlags,
250        buf: &[u8],
251    ) -> Result<(VhostUserConfig, VhostUserConfigPayload), vhost::Error> {
252        <Frontend as VhostUserFrontend>::get_config(self, offset, size, flags, buf)
253    }
254
255    fn set_config(
256        &mut self,
257        offset: u32,
258        flags: VhostUserConfigFlags,
259        buf: &[u8],
260    ) -> Result<(), vhost::Error> {
261        <Frontend as VhostUserFrontend>::set_config(self, offset, flags, buf)
262    }
263}
264
265pub type VhostUserHandle = VhostUserHandleImpl<Frontend>;
266
267/// vhost-user socket handle
268#[derive(Clone)]
269pub struct VhostUserHandleImpl<T: VhostUserHandleBackend> {
270    pub vu: T,
271    pub socket_path: String,
272}
273
274impl<T: VhostUserHandleBackend> std::fmt::Debug for VhostUserHandleImpl<T> {
275    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
276        f.debug_struct("VhostUserHandle")
277            .field("socket_path", &self.socket_path)
278            .finish()
279    }
280}
281
282impl<T: VhostUserHandleBackend> VhostUserHandleImpl<T> {
283    /// Connect to the vhost-user backend socket and mark self as an
284    /// owner of the session.
285    pub fn new(socket_path: &str, num_queues: u64) -> Result<Self, VhostUserError> {
286        let stream = UnixStream::connect(socket_path).map_err(VhostUserError::Connect)?;
287
288        let vu = T::from_stream(stream, num_queues);
289        vu.set_owner().map_err(VhostUserError::VhostUserSetOwner)?;
290
291        Ok(Self {
292            vu,
293            socket_path: socket_path.to_string(),
294        })
295    }
296
297    /// Set vhost-user features to the backend.
298    pub fn set_features(&self, features: u64) -> Result<(), VhostUserError> {
299        self.vu
300            .set_features(features)
301            .map_err(VhostUserError::VhostUserSetFeatures)
302    }
303
304    /// Set vhost-user protocol features to the backend.
305    pub fn set_protocol_features(
306        &mut self,
307        acked_features: u64,
308        acked_protocol_features: u64,
309    ) -> Result<(), VhostUserError> {
310        if acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() != 0
311            && let Some(acked_protocol_features) =
312                VhostUserProtocolFeatures::from_bits(acked_protocol_features)
313        {
314            self.vu
315                .set_protocol_features(acked_protocol_features)
316                .map_err(VhostUserError::VhostUserSetProtocolFeatures)?;
317
318            if acked_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK) {
319                self.vu.set_hdr_flags(VhostUserHeaderFlag::NEED_REPLY);
320            }
321        }
322
323        Ok(())
324    }
325
326    /// Negotiate virtio and protocol features with the backend.
327    pub fn negotiate_features(
328        &mut self,
329        avail_features: u64,
330        avail_protocol_features: VhostUserProtocolFeatures,
331    ) -> Result<(u64, u64), VhostUserError> {
332        // Get features from backend, do negotiation to get a feature collection which
333        // both VMM and backend support.
334        let backend_features = self
335            .vu
336            .get_features()
337            .map_err(VhostUserError::VhostUserGetFeatures)?;
338        let acked_features = avail_features & backend_features;
339
340        let acked_protocol_features =
341            // If frontend can negotiate protocol features.
342            if acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() != 0 {
343                let backend_protocol_features = self
344                    .vu
345                    .get_protocol_features()
346                    .map_err(VhostUserError::VhostUserGetProtocolFeatures)?;
347
348                let acked_protocol_features = avail_protocol_features & backend_protocol_features;
349
350                self.vu
351                    .set_protocol_features(acked_protocol_features)
352                    .map_err(VhostUserError::VhostUserSetProtocolFeatures)?;
353
354                acked_protocol_features
355            } else {
356                VhostUserProtocolFeatures::empty()
357            };
358
359        if acked_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK) {
360            self.vu.set_hdr_flags(VhostUserHeaderFlag::NEED_REPLY);
361        }
362
363        Ok((acked_features, acked_protocol_features.bits()))
364    }
365
366    /// Update guest memory table to the backend.
367    fn update_mem_table(&self, mem: &GuestMemoryMmap) -> Result<(), VhostUserError> {
368        let mut regions: Vec<VhostUserMemoryRegionInfo> = Vec::new();
369
370        for region in mem.iter() {
371            let (mmap_handle, mmap_offset) = match region.file_offset() {
372                Some(_file_offset) => (_file_offset.file().as_raw_fd(), _file_offset.start()),
373                None => {
374                    return Err(VhostUserError::VhostUserNoMemoryRegion);
375                }
376            };
377
378            let vhost_user_net_reg = VhostUserMemoryRegionInfo {
379                guest_phys_addr: region.start_addr().raw_value(),
380                memory_size: region.len(),
381                userspace_addr: region.inner.as_ptr() as u64,
382                mmap_offset,
383                mmap_handle,
384            };
385            regions.push(vhost_user_net_reg);
386        }
387
388        self.vu
389            .set_mem_table(regions.as_slice())
390            .map_err(VhostUserError::VhostUserSetMemTable)?;
391
392        Ok(())
393    }
394
395    /// Set up vhost-user backend. This includes updating memory table,
396    /// sending information about virtio rings and enabling them.
397    pub fn setup_backend(
398        &mut self,
399        mem: &GuestMemoryMmap,
400        queues: &[(usize, &Queue, &EventFd)],
401        interrupt: Arc<dyn VirtioInterrupt>,
402    ) -> Result<(), VhostUserError> {
403        // Provide the memory table to the backend.
404        self.update_mem_table(mem)?;
405
406        // Send set_vring_num here, since it could tell backends, like SPDK,
407        // how many virt queues to be handled, which backend required to know
408        // at early stage.
409        for (queue_index, queue, _) in queues.iter() {
410            self.vu
411                .set_vring_num(*queue_index, queue.size)
412                .map_err(VhostUserError::VhostUserSetVringNum)?;
413        }
414
415        for (queue_index, queue, queue_evt) in queues.iter() {
416            let config_data = VringConfigData {
417                queue_max_size: queue.max_size,
418                queue_size: queue.size,
419                flags: 0u32,
420                desc_table_addr: mem
421                    .get_host_address(queue.desc_table_address)
422                    .map_err(VhostUserError::DescriptorTableAddress)?
423                    as u64,
424                used_ring_addr: mem
425                    .get_host_address(queue.used_ring_address)
426                    .map_err(VhostUserError::UsedAddress)? as u64,
427                avail_ring_addr: mem
428                    .get_host_address(queue.avail_ring_address)
429                    .map_err(VhostUserError::AvailAddress)? as u64,
430                log_addr: None,
431            };
432
433            self.vu
434                .set_vring_addr(*queue_index, &config_data)
435                .map_err(VhostUserError::VhostUserSetVringAddr)?;
436            self.vu
437                .set_vring_base(*queue_index, queue.avail_ring_idx_get())
438                .map_err(VhostUserError::VhostUserSetVringBase)?;
439
440            // No matter the queue, we set irq_evt for signaling the guest that buffers were
441            // consumed.
442            self.vu
443                .set_vring_call(
444                    *queue_index,
445                    interrupt
446                        .notifier(VirtioInterruptType::Queue(
447                            (*queue_index).try_into().unwrap_or_else(|_| {
448                                panic!("vhost-user: invalid queue index: {}", *queue_index)
449                            }),
450                        ))
451                        .as_ref()
452                        .unwrap(),
453                )
454                .map_err(VhostUserError::VhostUserSetVringCall)?;
455
456            self.vu
457                .set_vring_kick(*queue_index, queue_evt)
458                .map_err(VhostUserError::VhostUserSetVringKick)?;
459
460            self.vu
461                .set_vring_enable(*queue_index, true)
462                .map_err(VhostUserError::VhostUserSetVringEnable)?;
463        }
464
465        Ok(())
466    }
467}
468
469#[cfg(test)]
470pub(crate) mod tests {
471    #![allow(clippy::undocumented_unsafe_blocks)]
472
473    use std::fs::File;
474
475    use vmm_sys_util::tempfile::TempFile;
476
477    use super::*;
478    use crate::devices::virtio::test_utils::default_interrupt;
479    use crate::test_utils::create_tmp_socket;
480    use crate::vstate::memory;
481    use crate::vstate::memory::{GuestAddress, GuestRegionMmapExt};
482
483    pub(crate) fn create_mem(file: File, regions: &[(GuestAddress, usize)]) -> GuestMemoryMmap {
484        GuestMemoryMmap::from_regions(
485            memory::create(
486                regions.iter().copied(),
487                libc::MAP_PRIVATE,
488                Some(file),
489                false,
490            )
491            .unwrap()
492            .into_iter()
493            .map(|region| GuestRegionMmapExt::dram_from_mmap_region(region, 0))
494            .collect(),
495        )
496        .unwrap()
497    }
498
499    #[test]
500    fn test_new() {
501        struct MockFrontend {
502            sock: UnixStream,
503            max_queue_num: u64,
504            is_owner: std::cell::UnsafeCell<bool>,
505        }
506
507        impl VhostUserHandleBackend for MockFrontend {
508            fn from_stream(sock: UnixStream, max_queue_num: u64) -> Self {
509                Self {
510                    sock,
511                    max_queue_num,
512                    is_owner: std::cell::UnsafeCell::new(false),
513                }
514            }
515
516            fn set_owner(&self) -> Result<(), vhost::Error> {
517                unsafe { *self.is_owner.get() = true };
518                Ok(())
519            }
520        }
521
522        let max_queue_num = 69;
523
524        let (_tmp_dir, tmp_socket_path) = create_tmp_socket();
525
526        // Creation of the VhostUserHandleImpl correctly connects to the socket, sets the maximum
527        // number of queues and sets itself as an owner of the session.
528        let vuh =
529            VhostUserHandleImpl::<MockFrontend>::new(&tmp_socket_path, max_queue_num).unwrap();
530        assert_eq!(
531            vuh.vu
532                .sock
533                .peer_addr()
534                .unwrap()
535                .as_pathname()
536                .unwrap()
537                .to_str()
538                .unwrap(),
539            &tmp_socket_path,
540        );
541        assert_eq!(vuh.vu.max_queue_num, max_queue_num);
542        assert!(unsafe { *vuh.vu.is_owner.get() });
543    }
544
545    #[test]
546    fn test_set_features() {
547        struct MockFrontend {
548            features: std::cell::UnsafeCell<u64>,
549        }
550
551        impl VhostUserHandleBackend for MockFrontend {
552            fn set_features(&self, features: u64) -> Result<(), vhost::Error> {
553                unsafe { *self.features.get() = features };
554                Ok(())
555            }
556        }
557
558        // VhostUserHandleImpl can correctly set backend features.
559        let vuh = VhostUserHandleImpl {
560            vu: MockFrontend { features: 0.into() },
561            socket_path: "".to_string(),
562        };
563        vuh.set_features(0x69).unwrap();
564        assert_eq!(unsafe { *vuh.vu.features.get() }, 0x69);
565    }
566
567    #[test]
568    fn test_set_protocol_features() {
569        struct MockFrontend {
570            protocol_features: VhostUserProtocolFeatures,
571            hdr_flags: std::cell::UnsafeCell<VhostUserHeaderFlag>,
572        }
573
574        impl VhostUserHandleBackend for MockFrontend {
575            fn set_hdr_flags(&self, flags: VhostUserHeaderFlag) {
576                unsafe { *self.hdr_flags.get() = flags };
577            }
578
579            fn set_protocol_features(
580                &mut self,
581                features: VhostUserProtocolFeatures,
582            ) -> Result<(), vhost::Error> {
583                self.protocol_features = features;
584                Ok(())
585            }
586        }
587
588        let mut vuh = VhostUserHandleImpl {
589            vu: MockFrontend {
590                protocol_features: VhostUserProtocolFeatures::empty(),
591                hdr_flags: std::cell::UnsafeCell::new(VhostUserHeaderFlag::empty()),
592            },
593            socket_path: "".to_string(),
594        };
595
596        // No protocol features are set if acked_features do not have PROTOCOL_FEATURES bit
597        let acked_features = 0;
598        let acked_protocol_features = VhostUserProtocolFeatures::empty();
599        vuh.set_protocol_features(acked_features, acked_protocol_features.bits())
600            .unwrap();
601        assert_eq!(vuh.vu.protocol_features, VhostUserProtocolFeatures::empty());
602        assert_eq!(
603            unsafe { &*vuh.vu.hdr_flags.get() }.bits(),
604            VhostUserHeaderFlag::empty().bits()
605        );
606
607        // No protocol features are set if acked_features do not have PROTOCOL_FEATURES bit
608        let acked_features = 0;
609        let acked_protocol_features = VhostUserProtocolFeatures::all();
610        vuh.set_protocol_features(acked_features, acked_protocol_features.bits())
611            .unwrap();
612        assert_eq!(vuh.vu.protocol_features, VhostUserProtocolFeatures::empty());
613        assert_eq!(
614            unsafe { &*vuh.vu.hdr_flags.get() }.bits(),
615            VhostUserHeaderFlag::empty().bits()
616        );
617
618        // If not REPLY_ACK present, no header is set
619        let acked_features = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
620        let mut acked_protocol_features = VhostUserProtocolFeatures::all();
621        acked_protocol_features.set(VhostUserProtocolFeatures::REPLY_ACK, false);
622        vuh.set_protocol_features(acked_features, acked_protocol_features.bits())
623            .unwrap();
624        assert_eq!(vuh.vu.protocol_features, acked_protocol_features);
625        assert_eq!(
626            unsafe { &*vuh.vu.hdr_flags.get() }.bits(),
627            VhostUserHeaderFlag::empty().bits()
628        );
629
630        // If REPLY_ACK present, header is set
631        let acked_features = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
632        let acked_protocol_features = VhostUserProtocolFeatures::all();
633        vuh.set_protocol_features(acked_features, acked_protocol_features.bits())
634            .unwrap();
635        assert_eq!(vuh.vu.protocol_features, acked_protocol_features);
636        assert_eq!(
637            unsafe { &*vuh.vu.hdr_flags.get() }.bits(),
638            VhostUserHeaderFlag::NEED_REPLY.bits()
639        );
640    }
641
642    #[test]
643    fn test_negotiate_features() {
644        struct MockFrontend {
645            features: u64,
646            protocol_features: VhostUserProtocolFeatures,
647            hdr_flags: std::cell::UnsafeCell<VhostUserHeaderFlag>,
648        }
649
650        impl VhostUserHandleBackend for MockFrontend {
651            fn set_hdr_flags(&self, flags: VhostUserHeaderFlag) {
652                unsafe { *self.hdr_flags.get() = flags };
653            }
654
655            fn get_features(&self) -> Result<u64, vhost::Error> {
656                Ok(self.features)
657            }
658
659            fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures, vhost::Error> {
660                Ok(self.protocol_features)
661            }
662
663            fn set_protocol_features(
664                &mut self,
665                features: VhostUserProtocolFeatures,
666            ) -> Result<(), vhost::Error> {
667                self.protocol_features = features;
668                Ok(())
669            }
670        }
671
672        let mut vuh = VhostUserHandleImpl {
673            vu: MockFrontend {
674                features: 0,
675                protocol_features: VhostUserProtocolFeatures::empty(),
676                hdr_flags: std::cell::UnsafeCell::new(VhostUserHeaderFlag::empty()),
677            },
678            socket_path: "".to_string(),
679        };
680
681        // If nothing is available, nothing is negotiated
682        let avail_features = 0;
683        let avail_protocol_features = VhostUserProtocolFeatures::empty();
684        let (acked_features, acked_protocol_features) = vuh
685            .negotiate_features(avail_features, avail_protocol_features)
686            .unwrap();
687        assert_eq!(acked_features, avail_features);
688        assert_eq!(acked_protocol_features, avail_protocol_features.bits());
689        assert_eq!(vuh.vu.protocol_features, VhostUserProtocolFeatures::empty());
690        assert_eq!(
691            unsafe { &*vuh.vu.hdr_flags.get() }.bits(),
692            VhostUserHeaderFlag::empty().bits()
693        );
694
695        // If neither frontend avail_features nor backend avail_features contain PROTOCOL_FEATURES
696        // bit, only features are negotiated
697        let mut avail_features = VhostUserVirtioFeatures::all();
698        avail_features.set(VhostUserVirtioFeatures::PROTOCOL_FEATURES, false);
699
700        // Pretend backend has same features as frontend
701        vuh.vu.features = avail_features.bits();
702
703        let avail_protocol_features = VhostUserProtocolFeatures::empty();
704        let (acked_features, acked_protocol_features) = vuh
705            .negotiate_features(avail_features.bits(), avail_protocol_features)
706            .unwrap();
707        assert_eq!(acked_features, avail_features.bits());
708        assert_eq!(acked_protocol_features, avail_protocol_features.bits());
709        assert_eq!(vuh.vu.protocol_features, VhostUserProtocolFeatures::empty());
710        assert_eq!(
711            unsafe { &*vuh.vu.hdr_flags.get() }.bits(),
712            VhostUserHeaderFlag::empty().bits()
713        );
714
715        // If PROTOCOL_FEATURES is negotiated, but REPLY_ACK is not, headers are not set
716        let avail_features = VhostUserVirtioFeatures::all();
717        // Pretend backend has same features as frontend
718        vuh.vu.features = avail_features.bits();
719
720        let mut avail_protocol_features = VhostUserProtocolFeatures::empty();
721        avail_protocol_features.set(VhostUserProtocolFeatures::CONFIG, true);
722
723        let mut backend_protocol_features = VhostUserProtocolFeatures::empty();
724        backend_protocol_features.set(VhostUserProtocolFeatures::CONFIG, true);
725        backend_protocol_features.set(VhostUserProtocolFeatures::PAGEFAULT, true);
726        vuh.vu.protocol_features = backend_protocol_features;
727
728        let (acked_features, acked_protocol_features) = vuh
729            .negotiate_features(avail_features.bits(), avail_protocol_features)
730            .unwrap();
731        assert_eq!(acked_features, avail_features.bits());
732        assert_eq!(acked_protocol_features, avail_protocol_features.bits());
733        assert_eq!(vuh.vu.protocol_features, avail_protocol_features);
734        assert_eq!(
735            unsafe { &*vuh.vu.hdr_flags.get() }.bits(),
736            VhostUserHeaderFlag::empty().bits()
737        );
738
739        // If PROTOCOL_FEATURES and REPLY_ACK are negotiated
740        let avail_features = VhostUserVirtioFeatures::all();
741        // Pretend backend has same features as frontend
742        vuh.vu.features = avail_features.bits();
743
744        let mut avail_protocol_features = VhostUserProtocolFeatures::empty();
745        avail_protocol_features.set(VhostUserProtocolFeatures::REPLY_ACK, true);
746
747        // Pretend backend has same features as frontend
748        vuh.vu.protocol_features = avail_protocol_features;
749
750        let (acked_features, acked_protocol_features) = vuh
751            .negotiate_features(avail_features.bits(), avail_protocol_features)
752            .unwrap();
753        assert_eq!(acked_features, avail_features.bits());
754        assert_eq!(acked_protocol_features, avail_protocol_features.bits());
755        assert_eq!(vuh.vu.protocol_features, avail_protocol_features);
756        assert_eq!(
757            unsafe { &*vuh.vu.hdr_flags.get() }.bits(),
758            VhostUserHeaderFlag::NEED_REPLY.bits(),
759        );
760    }
761
762    #[test]
763    fn test_update_mem_table() {
764        struct MockFrontend {
765            regions: std::cell::UnsafeCell<Vec<VhostUserMemoryRegionInfo>>,
766        }
767
768        impl VhostUserHandleBackend for MockFrontend {
769            fn set_mem_table(
770                &self,
771                regions: &[VhostUserMemoryRegionInfo],
772            ) -> Result<(), vhost::Error> {
773                unsafe { (*self.regions.get()).extend_from_slice(regions) }
774                Ok(())
775            }
776        }
777
778        let vuh = VhostUserHandleImpl {
779            vu: MockFrontend {
780                regions: std::cell::UnsafeCell::new(vec![]),
781            },
782            socket_path: "".to_string(),
783        };
784
785        let region_size = 0x10000;
786        let file = TempFile::new().unwrap().into_file();
787        let file_size = 2 * region_size;
788        file.set_len(file_size as u64).unwrap();
789        let regions = vec![
790            (GuestAddress(0x0), region_size),
791            (GuestAddress(0x10000), region_size),
792        ];
793
794        let guest_memory = create_mem(file, &regions);
795
796        vuh.update_mem_table(&guest_memory).unwrap();
797
798        // VhostUserMemoryRegionInfo should be correctly set by the VhostUserHandleImpl
799        let expected_regions = guest_memory
800            .iter()
801            .map(|region| VhostUserMemoryRegionInfo {
802                guest_phys_addr: region.start_addr().raw_value(),
803                memory_size: region.len(),
804                userspace_addr: region.inner.as_ptr() as u64,
805                mmap_offset: region.file_offset().unwrap().start(),
806                mmap_handle: region.file_offset().unwrap().file().as_raw_fd(),
807            })
808            .collect::<Vec<_>>();
809
810        for (region, expected) in (unsafe { &*vuh.vu.regions.get() })
811            .iter()
812            .zip(expected_regions)
813        {
814            // VhostUserMemoryRegionInfo does not implement Eq.
815            assert_eq!(region.guest_phys_addr, expected.guest_phys_addr);
816            assert_eq!(region.memory_size, expected.memory_size);
817            assert_eq!(region.userspace_addr, expected.userspace_addr);
818            assert_eq!(region.mmap_offset, expected.mmap_offset);
819            assert_eq!(region.mmap_handle, expected.mmap_handle);
820        }
821    }
822
823    #[test]
824    fn test_setup_backend() {
825        #[derive(Default)]
826        struct VringData {
827            index: usize,
828            size: u16,
829            config: VringConfigData,
830            base: u16,
831            call: i32,
832            kick: i32,
833            enable: bool,
834        }
835
836        struct MockFrontend {
837            vrings: std::cell::UnsafeCell<Vec<VringData>>,
838        }
839
840        impl VhostUserHandleBackend for MockFrontend {
841            fn set_mem_table(
842                &self,
843                _regions: &[VhostUserMemoryRegionInfo],
844            ) -> Result<(), vhost::Error> {
845                Ok(())
846            }
847
848            fn set_vring_num(&self, queue_index: usize, num: u16) -> Result<(), vhost::Error> {
849                unsafe {
850                    (*self.vrings.get()).push(VringData {
851                        index: queue_index,
852                        size: num,
853                        ..Default::default()
854                    })
855                };
856                Ok(())
857            }
858
859            fn set_vring_addr(
860                &self,
861                queue_index: usize,
862                config_data: &VringConfigData,
863            ) -> Result<(), vhost::Error> {
864                unsafe { (&mut (*self.vrings.get()))[queue_index].config = *config_data };
865                Ok(())
866            }
867
868            fn set_vring_base(&self, queue_index: usize, base: u16) -> Result<(), vhost::Error> {
869                unsafe { (&mut (*self.vrings.get()))[queue_index].base = base };
870                Ok(())
871            }
872
873            fn set_vring_call(&self, queue_index: usize, fd: &EventFd) -> Result<(), vhost::Error> {
874                unsafe { (&mut (*self.vrings.get()))[queue_index].call = fd.as_raw_fd() };
875                Ok(())
876            }
877
878            fn set_vring_kick(&self, queue_index: usize, fd: &EventFd) -> Result<(), vhost::Error> {
879                unsafe { (&mut (*self.vrings.get()))[queue_index].kick = fd.as_raw_fd() };
880                Ok(())
881            }
882
883            fn set_vring_enable(
884                &mut self,
885                queue_index: usize,
886                enable: bool,
887            ) -> Result<(), vhost::Error> {
888                unsafe { &mut *self.vrings.get() }
889                    .get_mut(queue_index)
890                    .unwrap()
891                    .enable = enable;
892                Ok(())
893            }
894        }
895
896        let mut vuh = VhostUserHandleImpl {
897            vu: MockFrontend {
898                vrings: std::cell::UnsafeCell::new(vec![]),
899            },
900            socket_path: "".to_string(),
901        };
902
903        let region_size = 0x10000;
904        let file = TempFile::new().unwrap().into_file();
905        file.set_len(region_size as u64).unwrap();
906        let regions = vec![(GuestAddress(0x0), region_size)];
907
908        let guest_memory = create_mem(file, &regions);
909
910        let mut queue = Queue::new(128);
911        queue.ready = true;
912        queue.size = queue.max_size;
913        queue.initialize(&guest_memory).unwrap();
914
915        let event_fd = EventFd::new(0).unwrap();
916
917        let queues = [(0, &queue, &event_fd)];
918
919        let interrupt = default_interrupt();
920        vuh.setup_backend(&guest_memory, &queues, interrupt.clone())
921            .unwrap();
922
923        // VhostUserHandleImpl should correctly send memory and queues information to
924        // the backend.
925        let expected_config = VringData {
926            index: 0,
927            size: 128,
928            config: VringConfigData {
929                queue_max_size: 128,
930                queue_size: 128,
931                flags: 0,
932                desc_table_addr: guest_memory
933                    .get_host_address(queue.desc_table_address)
934                    .unwrap() as u64,
935                used_ring_addr: guest_memory
936                    .get_host_address(queue.used_ring_address)
937                    .unwrap() as u64,
938                avail_ring_addr: guest_memory
939                    .get_host_address(queue.avail_ring_address)
940                    .unwrap() as u64,
941                log_addr: None,
942            },
943            base: queue.avail_ring_idx_get(),
944            call: interrupt
945                .notifier(VirtioInterruptType::Queue(0u16))
946                .as_ref()
947                .unwrap()
948                .as_raw_fd(),
949            kick: event_fd.as_raw_fd(),
950            enable: true,
951        };
952
953        let result = unsafe { &*vuh.vu.vrings.get() };
954        assert_eq!(result.len(), 1);
955        assert_eq!(result[0].index, expected_config.index);
956        assert_eq!(result[0].size, expected_config.size);
957
958        // VringConfigData does not implement Eq.
959        assert_eq!(
960            result[0].config.queue_max_size,
961            expected_config.config.queue_max_size
962        );
963        assert_eq!(
964            result[0].config.queue_size,
965            expected_config.config.queue_size
966        );
967        assert_eq!(result[0].config.flags, expected_config.config.flags);
968        assert_eq!(
969            result[0].config.desc_table_addr,
970            expected_config.config.desc_table_addr
971        );
972        assert_eq!(
973            result[0].config.used_ring_addr,
974            expected_config.config.used_ring_addr
975        );
976        assert_eq!(
977            result[0].config.avail_ring_addr,
978            expected_config.config.avail_ring_addr
979        );
980        assert_eq!(result[0].config.log_addr, expected_config.config.log_addr);
981
982        assert_eq!(result[0].base, expected_config.base);
983        assert_eq!(result[0].call, expected_config.call);
984        assert_eq!(result[0].kick, expected_config.kick);
985        assert_eq!(result[0].enable, expected_config.enable);
986    }
987}