// vmm/devices/virtio/vsock/unix/muxer.rs
// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
//

/// `VsockMuxer` is the device-facing component of the Unix domain sockets vsock backend. I.e.
/// by implementing the `VsockBackend` trait, it abstracts away the gory details of translating
/// between AF_VSOCK and AF_UNIX, and presents a clean interface to the rest of the vsock
/// device model.
///
/// The vsock muxer has two main roles:
/// 1. Vsock connection multiplexer: It's the muxer's job to create, manage, and terminate
///    `VsockConnection` objects. The muxer also routes packets to their owning connections. It
///    does so via a connection `HashMap`, keyed by what is basically a (host_port, guest_port)
///    tuple. Vsock packet traffic needs to be inspected, in order to detect connection request
///    packets (leading to the creation of a new connection), and connection reset packets
///    (leading to the termination of an existing connection). All other packets, though, must
///    belong to an existing connection and, as such, the muxer simply forwards them.
/// 2. Event dispatcher: There are three event categories that the vsock backend is interested
///    in:
///    1. A new host-initiated connection is ready to be accepted from the listening host Unix
///       socket;
///    2. Data is available for reading from a newly-accepted host-initiated connection (i.e.
///       the host is ready to issue a vsock connection request, informing us of the
///       destination port to which it wants to connect);
///    3. Some event was triggered for a connected Unix socket, that belongs to a
///       `VsockConnection`.
///
///  The muxer gets notified about all of these events, because, as a `VsockEpollListener`
///  implementor, it gets to register a nested epoll FD into the main VMM epolling loop. All
///  other pollable FDs are then registered under this nested epoll FD.
///  To route all these events to their handlers, the muxer uses another `HashMap` object,
///  mapping `RawFd`s to `EpollListener`s.
33use std::collections::{HashMap, HashSet};
34use std::fmt::Debug;
35use std::io::Read;
36use std::os::unix::io::{AsRawFd, RawFd};
37use std::os::unix::net::{UnixListener, UnixStream};
38
39use log::{debug, error, info, warn};
40use vmm_sys_util::epoll::{ControlOperation, Epoll, EpollEvent, EventSet};
41
42use super::super::csm::ConnState;
43use super::super::defs::uapi;
44use super::super::{VsockBackend, VsockChannel, VsockEpollListener, VsockError};
45use super::muxer_killq::MuxerKillQ;
46use super::muxer_rxq::MuxerRxQ;
47use super::{MuxerConnection, VsockUnixBackendError, defs};
48use crate::devices::virtio::vsock::metrics::METRICS;
49use crate::devices::virtio::vsock::packet::{VsockPacketRx, VsockPacketTx};
50use crate::logger::IncMetric;
51
/// A unique identifier of a `MuxerConnection` object. Connections are stored in a hash map,
/// keyed by a `ConnMapKey` object.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct ConnMapKey {
    // The host-side (local) vsock port of the connection.
    local_port: u32,
    // The guest-side (peer) vsock port of the connection.
    peer_port: u32,
}
59
/// A muxer RX queue item.
#[derive(Clone, Copy, Debug)]
pub enum MuxerRx {
    /// The packet must be fetched from the connection identified by `ConnMapKey`.
    ConnRx(ConnMapKey),
    /// The muxer must produce an RST packet. Used when there is no owning connection that
    /// could yield the RST itself (e.g. replying to an unroutable guest packet).
    RstPkt { local_port: u32, peer_port: u32 },
}
68
/// An epoll listener, registered under the muxer's nested epoll FD.
#[derive(Debug)]
enum EpollListener {
    /// The listener is a `MuxerConnection`, identified by `key`, and interested in the events
    /// in `evset`. Since `MuxerConnection` implements `VsockEpollListener`, notifications will
    /// be forwarded to the listener via `VsockEpollListener::notify()`.
    Connection { key: ConnMapKey, evset: EventSet },
    /// A listener interested in new host-initiated connections.
    HostSock,
    /// A listener interested in reading host `connect <port>` commands from a freshly
    /// connected host socket. The stream is stored here (owned) until the command has been
    /// read and the stream gets promoted to a full `MuxerConnection`.
    LocalStream(UnixStream),
}
82
/// The vsock connection multiplexer.
#[derive(Debug)]
pub struct VsockMuxer {
    /// Guest CID.
    cid: u64,
    /// A hash map used to store the active connections.
    conn_map: HashMap<ConnMapKey, MuxerConnection>,
    /// A hash map used to store epoll event listeners / handlers, keyed by the polled FD.
    listener_map: HashMap<RawFd, EpollListener>,
    /// The RX queue. Items in this queue are consumed by `VsockMuxer::recv_pkt()`, and
    /// produced
    /// - by `VsockMuxer::send_pkt()` (e.g. RST in response to a connection request packet); and
    /// - in response to EPOLLIN events (e.g. data available to be read from an AF_UNIX socket).
    rxq: MuxerRxQ,
    /// A queue used for terminating connections that are taking too long to shut down.
    killq: MuxerKillQ,
    /// The Unix socket, through which host-initiated connections are accepted.
    host_sock: UnixListener,
    /// The file system path of the host-side Unix socket. This is used to figure out the path
    /// to Unix sockets listening on specific ports. I.e. `"<this path>_<port number>"`.
    pub(crate) host_sock_path: String,
    /// The nested epoll event set, used to register epoll listeners.
    epoll: Epoll,
    /// A hash set used to keep track of used host-side (local) ports, in order to assign local
    /// ports to host-initiated connections.
    local_port_set: HashSet<u32>,
    /// The last used host-side port.
    local_port_last: u32,
}
112
impl VsockChannel for VsockMuxer {
    /// Deliver a vsock packet to the guest vsock driver.
    ///
    /// Returns:
    /// - `Ok(())`: `pkt` has been successfully filled in; or
    /// - `Err(VsockError::NoData)`: there was no available data with which to fill in the packet.
    fn recv_pkt(&mut self, pkt: &mut VsockPacketRx) -> Result<(), VsockError> {
        // We'll look for instructions on how to build the RX packet in the RX queue. If the
        // queue is empty, that doesn't necessarily mean we don't have any pending RX, since
        // the queue might be out-of-sync. If that's the case, we'll attempt to sync it first,
        // and then try to pop something out again.
        if self.rxq.is_empty() && !self.rxq.is_synced() {
            self.rxq = MuxerRxQ::from_conn_map(&self.conn_map);
        }

        while let Some(rx) = self.rxq.peek() {
            let res = match rx {
                // We need to build an RST packet, going from `local_port` to `peer_port`.
                MuxerRx::RstPkt {
                    local_port,
                    peer_port,
                } => {
                    pkt.hdr
                        .set_op(uapi::VSOCK_OP_RST)
                        .set_src_cid(uapi::VSOCK_HOST_CID)
                        .set_dst_cid(self.cid)
                        .set_src_port(local_port)
                        .set_dst_port(peer_port)
                        .set_len(0)
                        .set_type(uapi::VSOCK_TYPE_STREAM)
                        .set_flags(0)
                        .set_buf_alloc(0)
                        .set_fwd_cnt(0);
                    // An RST built by the muxer itself is always deliverable, so this queue
                    // item is consumed unconditionally.
                    self.rxq.pop().unwrap();
                    return Ok(());
                }

                // We'll defer building the packet to this connection, since it has something
                // to say.
                MuxerRx::ConnRx(key) => {
                    let mut conn_res = Err(VsockError::NoData);
                    let mut do_pop = true;
                    self.apply_conn_mutation(key, |conn| {
                        conn_res = conn.recv_pkt(pkt);
                        // Keep the queue item around for as long as the connection still has
                        // pending RX data to yield.
                        do_pop = !conn.has_pending_rx();
                    });
                    if do_pop {
                        self.rxq.pop().unwrap();
                    }
                    conn_res
                }
            };

            if res.is_ok() {
                // Inspect traffic, looking for RST packets, since that means we have to
                // terminate and remove this connection from the active connection pool.
                //
                if pkt.hdr.op() == uapi::VSOCK_OP_RST {
                    self.remove_connection(ConnMapKey {
                        local_port: pkt.hdr.src_port(),
                        peer_port: pkt.hdr.dst_port(),
                    });
                }

                debug!("vsock muxer: RX pkt: {:?}", pkt.hdr);
                return Ok(());
            }
        }

        Err(VsockError::NoData)
    }

    /// Deliver a guest-generated packet to its destination in the vsock backend.
    ///
    /// This absorbs unexpected packets, handles RSTs (by dropping connections), and forwards
    /// all the rest to their owning `MuxerConnection`.
    ///
    /// Returns:
    /// always `Ok(())` - the packet has been consumed, and its virtio TX buffers can be
    /// returned to the guest vsock driver.
    fn send_pkt(&mut self, pkt: &VsockPacketTx) -> Result<(), VsockError> {
        // Note the src/dst flip: the guest's source port is our peer port, and its
        // destination port is our local port.
        let conn_key = ConnMapKey {
            local_port: pkt.hdr.dst_port(),
            peer_port: pkt.hdr.src_port(),
        };

        debug!(
            "vsock: muxer.send[rxq.len={}]: {:?}",
            self.rxq.len(),
            pkt.hdr
        );

        // If this packet has an unsupported type (!=stream), we must send back an RST.
        //
        if pkt.hdr.type_() != uapi::VSOCK_TYPE_STREAM {
            self.enq_rst(pkt.hdr.dst_port(), pkt.hdr.src_port());
            return Ok(());
        }

        // We don't know how to handle packets addressed to other CIDs. We only handle the host
        // part of the guest - host communication here.
        if pkt.hdr.dst_cid() != uapi::VSOCK_HOST_CID {
            info!(
                "vsock: dropping guest packet for unknown CID: {:?}",
                pkt.hdr
            );
            return Ok(());
        }

        if !self.conn_map.contains_key(&conn_key) {
            // This packet can't be routed to any active connection (based on its src and dst
            // ports).  The only orphan / unroutable packets we know how to handle are
            // connection requests.
            if pkt.hdr.op() == uapi::VSOCK_OP_REQUEST {
                // Oh, this is a connection request!
                self.handle_peer_request_pkt(pkt);
            } else {
                // Send back an RST, to let the driver know we weren't expecting this packet.
                self.enq_rst(pkt.hdr.dst_port(), pkt.hdr.src_port());
            }
            return Ok(());
        }

        // Right, we know where to send this packet, then (to `conn_key`).
        // However, if this is an RST, we have to forcefully terminate the connection, so
        // there's no point in forwarding the packet to it.
        if pkt.hdr.op() == uapi::VSOCK_OP_RST {
            self.remove_connection(conn_key);
            return Ok(());
        }

        // Alright, everything looks in order - forward this packet to its owning connection.
        let mut res: Result<(), VsockError> = Ok(());
        self.apply_conn_mutation(conn_key, |conn| {
            res = conn.send_pkt(pkt);
        });

        res
    }

    /// Check if the muxer has any pending RX data, with which to fill a guest-provided RX
    /// buffer. An out-of-sync RX queue also counts as "pending", since re-syncing it may
    /// surface packets that the queue had dropped.
    fn has_pending_rx(&self) -> bool {
        !self.rxq.is_empty() || !self.rxq.is_synced()
    }
}
259
impl AsRawFd for VsockMuxer {
    /// Get the FD to be registered for polling upstream (in the main VMM epoll loop, in this
    /// case).
    ///
    /// This will be the muxer's nested epoll FD, under which all backend FDs (host socket,
    /// local streams, connections) are registered.
    fn as_raw_fd(&self) -> RawFd {
        self.epoll.as_raw_fd()
    }
}
269
impl VsockEpollListener for VsockMuxer {
    /// Get the epoll events to be polled upstream.
    ///
    /// Since the polled FD is a nested epoll FD, we're only interested in EPOLLIN events (i.e.
    /// some event occurred on one of the FDs registered under our epoll FD).
    fn get_polled_evset(&self) -> EventSet {
        EventSet::IN
    }

    /// Notify the muxer about a pending event having occurred under its nested epoll FD.
    fn notify(&mut self, _: EventSet) {
        // Batch up to 32 ready events per notification; a timeout of 0 makes this a
        // non-blocking poll of the nested epoll FD.
        let mut epoll_events = vec![EpollEvent::new(EventSet::empty(), 0); 32];
        match self.epoll.wait(0, epoll_events.as_mut_slice()) {
            Ok(ev_cnt) => {
                for ev in &epoll_events[0..ev_cnt] {
                    self.handle_event(
                        ev.fd(),
                        // It's ok to unwrap here, since the `epoll_events[i].events` is filled
                        // in by `epoll::wait()`, and therefore contains only valid epoll
                        // flags.
                        EventSet::from_bits(ev.events).unwrap(),
                    );
                }
            }
            Err(err) => {
                warn!("vsock: failed to consume muxer epoll event: {}", err);
                METRICS.muxer_event_fails.inc();
            }
        }
    }
}
301
302impl VsockBackend for VsockMuxer {}
303
304impl VsockMuxer {
    /// Muxer constructor.
    ///
    /// Binds (and starts listening on) the host-side Unix socket at `host_sock_path`, creates
    /// the nested epoll FD, and registers the host socket under it.
    ///
    /// # Errors
    /// - `VsockUnixBackendError::UnixBind`: binding or configuring the host socket failed;
    /// - `VsockUnixBackendError::EpollFdCreate`: creating the nested epoll FD failed;
    /// - whatever `add_listener()` may return, when registering the host socket.
    pub fn new(cid: u64, host_sock_path: String) -> Result<Self, VsockUnixBackendError> {
        // Open/bind on the host Unix socket, so we can accept host-initiated
        // connections.
        let host_sock = UnixListener::bind(&host_sock_path)
            .and_then(|sock| sock.set_nonblocking(true).map(|_| sock))
            .map_err(VsockUnixBackendError::UnixBind)?;

        let mut muxer = Self {
            cid,
            host_sock,
            host_sock_path,
            epoll: Epoll::new().map_err(VsockUnixBackendError::EpollFdCreate)?,
            rxq: MuxerRxQ::new(),
            conn_map: HashMap::with_capacity(defs::MAX_CONNECTIONS),
            // One extra listener slot for the host socket itself.
            listener_map: HashMap::with_capacity(defs::MAX_CONNECTIONS + 1),
            killq: MuxerKillQ::new(),
            // Start just below 1 << 30, so that the first allocated local port is 1 << 30
            // (see `allocate_local_port()`).
            local_port_last: (1u32 << 30) - 1,
            local_port_set: HashSet::with_capacity(defs::MAX_CONNECTIONS),
        };

        // Listen on the host initiated socket, for incoming connections.
        muxer.add_listener(muxer.host_sock.as_raw_fd(), EpollListener::HostSock)?;
        Ok(muxer)
    }
330
331    /// Return the file system path of the host-side Unix socket.
332    pub fn host_sock_path(&self) -> &str {
333        &self.host_sock_path
334    }
335
    /// Handle/dispatch an epoll event to its listener.
    ///
    /// `fd` is the FD the event fired on, and is used to look up the owning listener; events
    /// for FDs with no registered listener are logged and counted as failures.
    fn handle_event(&mut self, fd: RawFd, event_set: EventSet) {
        debug!(
            "vsock: muxer processing event: fd={}, evset={:?}",
            fd, event_set
        );

        match self.listener_map.get_mut(&fd) {
            // This event needs to be forwarded to a `MuxerConnection` that is listening for
            // it.
            Some(EpollListener::Connection { key, evset: _ }) => {
                let key_copy = *key;
                // The handling of this event will most probably mutate the state of the
                // receiving connection. We'll need to check for new pending RX, event set
                // mutation, and all that, so we're wrapping the event delivery inside those
                // checks.
                self.apply_conn_mutation(key_copy, |conn| {
                    conn.notify(event_set);
                });
            }

            // A new host-initiated connection is ready to be accepted.
            Some(EpollListener::HostSock) => {
                if self.conn_map.len() == defs::MAX_CONNECTIONS {
                    // If we're already maxed-out on connections, we'll just accept and
                    // immediately discard this potentially new one.
                    warn!("vsock: connection limit reached; refusing new host connection");
                    // Accept (draining the listener's backlog entry) and immediately drop the
                    // stream; the result is deliberately ignored.
                    self.host_sock.accept().map(|_| 0).unwrap_or(0);
                    return;
                }
                self.host_sock
                    .accept()
                    .map_err(VsockUnixBackendError::UnixAccept)
                    .and_then(|(stream, _)| {
                        stream
                            .set_nonblocking(true)
                            .map(|_| stream)
                            .map_err(VsockUnixBackendError::UnixAccept)
                    })
                    .and_then(|stream| {
                        // Before forwarding this connection to a listening AF_VSOCK socket on
                        // the guest side, we need to know the destination port. We'll read
                        // that port from a "connect" command received on this socket, so the
                        // next step is to ask to be notified the moment we can read from it.
                        self.add_listener(stream.as_raw_fd(), EpollListener::LocalStream(stream))
                    })
                    .unwrap_or_else(|err| {
                        warn!("vsock: unable to accept local connection: {:?}", err);
                    });
            }

            // Data is ready to be read from a host-initiated connection. That would be the
            // "connect" command that we're expecting.
            Some(EpollListener::LocalStream(_)) => {
                // Take ownership of the stream back from the listener map; if the "connect"
                // command parses, the stream is promoted to a full `MuxerConnection` (with its
                // own `Connection` listener); otherwise it is dropped here.
                if let Some(EpollListener::LocalStream(mut stream)) = self.remove_listener(fd) {
                    Self::read_local_stream_port(&mut stream)
                        .map(|peer_port| (self.allocate_local_port(), peer_port))
                        .and_then(|(local_port, peer_port)| {
                            self.add_connection(
                                ConnMapKey {
                                    local_port,
                                    peer_port,
                                },
                                MuxerConnection::new_local_init(
                                    stream,
                                    uapi::VSOCK_HOST_CID,
                                    self.cid,
                                    local_port,
                                    peer_port,
                                ),
                            )
                        })
                        .unwrap_or_else(|err| {
                            info!("vsock: error adding local-init connection: {:?}", err);
                        })
                }
            }

            _ => {
                info!(
                    "vsock: unexpected event: fd={:?}, evset={:?}",
                    fd, event_set
                );
                METRICS.muxer_event_fails.inc();
            }
        }
    }
423
424    /// Parse a host "connect" command, and extract the destination vsock port.
425    fn read_local_stream_port(stream: &mut UnixStream) -> Result<u32, VsockUnixBackendError> {
426        let mut buf = [0u8; 32];
427
428        // This is the minimum number of bytes that we should be able to read, when parsing a
429        // valid connection request. I.e. `b"connect 0\n".len()`.
430        const MIN_READ_LEN: usize = 10;
431
432        // Bring in the minimum number of bytes that we should be able to read.
433        stream
434            .read_exact(&mut buf[..MIN_READ_LEN])
435            .map_err(VsockUnixBackendError::UnixRead)?;
436
437        // Now, finish reading the destination port number, by bringing in one byte at a time,
438        // until we reach an EOL terminator (or our buffer space runs out).  Yeah, not
439        // particularly proud of this approach, but it will have to do for now.
440        let mut blen = MIN_READ_LEN;
441        while buf[blen - 1] != b'\n' && blen < buf.len() {
442            stream
443                .read_exact(&mut buf[blen..=blen])
444                .map_err(VsockUnixBackendError::UnixRead)?;
445            blen += 1;
446        }
447
448        let mut word_iter = std::str::from_utf8(&buf[..blen])
449            .map_err(|_| VsockUnixBackendError::InvalidPortRequest)?
450            .split_whitespace();
451
452        word_iter
453            .next()
454            .ok_or(VsockUnixBackendError::InvalidPortRequest)
455            .and_then(|word| {
456                if word.to_lowercase() == "connect" {
457                    Ok(())
458                } else {
459                    Err(VsockUnixBackendError::InvalidPortRequest)
460                }
461            })
462            .and_then(|_| {
463                word_iter
464                    .next()
465                    .ok_or(VsockUnixBackendError::InvalidPortRequest)
466            })
467            .and_then(|word| {
468                word.parse::<u32>()
469                    .map_err(|_| VsockUnixBackendError::InvalidPortRequest)
470            })
471            .map_err(|_| VsockUnixBackendError::InvalidPortRequest)
472    }
473
    /// Add a new connection to the active connection pool.
    ///
    /// Registers the connection's FD under the nested epoll FD, schedules it for RX if it
    /// already has something to yield, and stores it in the connection map.
    ///
    /// # Errors
    /// - `VsockUnixBackendError::TooManyConnections`: the connection limit has been reached;
    /// - whatever `add_listener()` may return, when registering the connection's FD.
    fn add_connection(
        &mut self,
        key: ConnMapKey,
        conn: MuxerConnection,
    ) -> Result<(), VsockUnixBackendError> {
        // We might need to make room for this new connection, so let's sweep the kill queue
        // first.  It's fine to do this here because:
        // - unless the kill queue is out of sync, this is a pretty inexpensive operation; and
        // - we are under no pressure to respect any accurate timing for connection termination.
        self.sweep_killq();

        if self.conn_map.len() >= defs::MAX_CONNECTIONS {
            info!(
                "vsock: muxer connection limit reached ({})",
                defs::MAX_CONNECTIONS
            );
            return Err(VsockUnixBackendError::TooManyConnections);
        }

        self.add_listener(
            conn.as_raw_fd(),
            EpollListener::Connection {
                key,
                evset: conn.get_polled_evset(),
            },
        )
        .map(|_| {
            if conn.has_pending_rx() {
                // We can safely ignore any error in adding a connection RX indication. Worst
                // case scenario, the RX queue will get desynchronized, but we'll handle that
                // the next time we need to yield an RX packet.
                self.rxq.push(MuxerRx::ConnRx(key));
            }
            self.conn_map.insert(key, conn);
            METRICS.conns_added.inc();
        })
    }
512
    /// Remove a connection from the active connection pool.
    ///
    /// Also deregisters the connection's epoll listener. The local port is freed
    /// unconditionally, even if no connection was found under `key`.
    fn remove_connection(&mut self, key: ConnMapKey) {
        if let Some(conn) = self.conn_map.remove(&key) {
            self.remove_listener(conn.as_raw_fd());
            METRICS.conns_removed.inc();
        }
        self.free_local_port(key.local_port);
    }
521
    /// Schedule a connection for immediate termination.
    /// I.e. as soon as we can also let our peer know we're dropping the connection, by sending
    /// it an RST packet.
    fn kill_connection(&mut self, key: ConnMapKey) {
        let mut had_rx = false;
        METRICS.conns_killed.inc();

        // If the connection exists, record whether it already had pending RX, then mark it
        // killed (after which it will yield an RST).
        self.conn_map.entry(key).and_modify(|conn| {
            had_rx = conn.has_pending_rx();
            conn.kill();
        });
        // This connection will now have an RST packet to yield, so we need to add it to the RX
        // queue.  However, there's no point in doing that if it was already in the queue.
        if !had_rx {
            // We can safely ignore any error in adding a connection RX indication. Worst case
            // scenario, the RX queue will get desynchronized, but we'll handle that the next
            // time we need to yield an RX packet.
            self.rxq.push(MuxerRx::ConnRx(key));
        }
    }
542
543    /// Register a new epoll listener under the muxer's nested epoll FD.
544    fn add_listener(
545        &mut self,
546        fd: RawFd,
547        listener: EpollListener,
548    ) -> Result<(), VsockUnixBackendError> {
549        let evset = match listener {
550            EpollListener::Connection { evset, .. } => evset,
551            EpollListener::LocalStream(_) => EventSet::IN,
552            EpollListener::HostSock => EventSet::IN,
553        };
554
555        self.epoll
556            .ctl(
557                ControlOperation::Add,
558                fd,
559                EpollEvent::new(evset, u64::try_from(fd).unwrap()),
560            )
561            .map(|_| {
562                self.listener_map.insert(fd, listener);
563            })
564            .map_err(VsockUnixBackendError::EpollAdd)?;
565
566        Ok(())
567    }
568
569    /// Remove (and return) a previously registered epoll listener.
570    fn remove_listener(&mut self, fd: RawFd) -> Option<EpollListener> {
571        let maybe_listener = self.listener_map.remove(&fd);
572
573        if maybe_listener.is_some() {
574            self.epoll
575                .ctl(ControlOperation::Delete, fd, EpollEvent::default())
576                .unwrap_or_else(|err| {
577                    warn!(
578                        "vosck muxer: error removing epoll listener for fd {:?}: {:?}",
579                        fd, err
580                    );
581                });
582        }
583
584        maybe_listener
585    }
586
    /// Allocate a host-side port to be assigned to a new host-initiated connection.
    ///
    /// Ports are handed out sequentially from the range `[1 << 30, 1 << 31)`: the bit mask
    /// below clears bit 31 and sets bit 30, which wraps the counter around within that range.
    /// Already-used ports (still present in `local_port_set`) are skipped.
    fn allocate_local_port(&mut self) -> u32 {
        // TODO: this doesn't seem very space-efficient.
        // Maybe rewrite this to limit port range and use a bitmap?
        //
        // NOTE(review): this loop assumes the used-port set never covers the entire
        // 2^30-port range (in practice it is bounded by `MAX_CONNECTIONS`), otherwise it
        // would never terminate.

        loop {
            self.local_port_last = (self.local_port_last + 1) & !(1 << 31) | (1 << 30);
            // `HashSet::insert` returns true iff the port was not already in use.
            if self.local_port_set.insert(self.local_port_last) {
                break;
            }
        }
        self.local_port_last
    }
601
    /// Mark a previously used host-side port as free.
    ///
    /// A no-op if `port` was not in the used-port set.
    fn free_local_port(&mut self, port: u32) {
        self.local_port_set.remove(&port);
    }
606
    /// Handle a new connection request coming from our peer (the guest vsock driver).
    ///
    /// This will attempt to connect to a host-side Unix socket, expected to be listening at
    /// the file system path corresponding to the destination port. If successful, a new
    /// connection object will be created and added to the connection pool. On failure, a new
    /// RST packet will be scheduled for delivery to the guest.
    fn handle_peer_request_pkt(&mut self, pkt: &VsockPacketTx) {
        // Per convention, the listening socket for guest port N lives at
        // `"<host_sock_path>_<N>"`.
        let port_path = format!("{}_{}", self.host_sock_path, pkt.hdr.dst_port());

        UnixStream::connect(port_path)
            .and_then(|stream| stream.set_nonblocking(true).map(|_| stream))
            .map_err(VsockUnixBackendError::UnixConnect)
            .and_then(|stream| {
                self.add_connection(
                    ConnMapKey {
                        local_port: pkt.hdr.dst_port(),
                        peer_port: pkt.hdr.src_port(),
                    },
                    MuxerConnection::new_peer_init(
                        stream,
                        uapi::VSOCK_HOST_CID,
                        self.cid,
                        pkt.hdr.dst_port(),
                        pkt.hdr.src_port(),
                        pkt.hdr.buf_alloc(),
                    ),
                )
            })
            // Any failure along the way (connect, non-blocking setup, pool insertion) is
            // answered with an RST to the guest.
            .unwrap_or_else(|_| self.enq_rst(pkt.hdr.dst_port(), pkt.hdr.src_port()));
    }
637
    /// Perform an action that might mutate a connection's state.
    ///
    /// This is used as shorthand for repetitive tasks that need to be performed after a
    /// connection object mutates. E.g.
    /// - update the connection's epoll listener;
    /// - schedule the connection to be queried for RX data;
    /// - kill the connection if an unrecoverable error occurs.
    ///
    /// A no-op if no connection exists under `key`.
    fn apply_conn_mutation<F>(&mut self, key: ConnMapKey, mut_fn: F)
    where
        F: FnOnce(&mut MuxerConnection),
    {
        if let Some(conn) = self.conn_map.get_mut(&key) {
            // Snapshot the pre-mutation state, so we can tell what changed afterwards.
            let had_rx = conn.has_pending_rx();
            let was_expiring = conn.will_expire();
            let prev_state = conn.state();

            mut_fn(conn);

            // If this is a host-initiated connection that has just become established, we'll have
            // to send an ack message to the host end.
            if prev_state == ConnState::LocalInit && conn.state() == ConnState::Established {
                let msg = format!("OK {}\n", key.local_port);
                match conn.send_bytes_raw(msg.as_bytes()) {
                    Ok(written) if written == msg.len() => (),
                    Ok(_) => {
                        // If we can't write a dozen bytes to a pristine connection something
                        // must be really wrong. Killing it.
                        conn.kill();
                        warn!("vsock: unable to fully write connection ack msg.");
                    }
                    Err(err) => {
                        conn.kill();
                        warn!("vsock: unable to ack host connection: {:?}", err);
                    }
                };
            }

            // If the connection wasn't previously scheduled for RX, add it to our RX queue.
            if !had_rx && conn.has_pending_rx() {
                self.rxq.push(MuxerRx::ConnRx(key));
            }

            // If the connection wasn't previously scheduled for termination, add it to the
            // kill queue.
            if !was_expiring && conn.will_expire() {
                // It's safe to unwrap here, since `conn.will_expire()` already guaranteed that
                // an `conn.expiry` is available.
                self.killq.push(key, conn.expiry().unwrap());
            }

            let fd = conn.as_raw_fd();
            let new_evset = conn.get_polled_evset();
            if new_evset.is_empty() {
                // If the connection no longer needs epoll notifications, remove its listener
                // from our list.
                self.remove_listener(fd);
                return;
            }
            if let Some(EpollListener::Connection { evset, .. }) = self.listener_map.get_mut(&fd) {
                if *evset != new_evset {
                    // If the set of events that the connection is interested in has changed,
                    // we need to update its epoll listener.
                    debug!(
                        "vsock: updating listener for (lp={}, pp={}): old={:?}, new={:?}",
                        key.local_port, key.peer_port, *evset, new_evset
                    );

                    *evset = new_evset;
                    self.epoll
                        .ctl(
                            ControlOperation::Modify,
                            fd,
                            EpollEvent::new(new_evset, u64::try_from(fd).unwrap()),
                        )
                        .unwrap_or_else(|err| {
                            // This really shouldn't happen, like, ever. However, "famous last
                            // words" and all that, so let's just kill it with fire, and walk away.
                            self.kill_connection(key);
                            error!(
                                "vsock: error updating epoll listener for (lp={}, pp={}): {:?}",
                                key.local_port, key.peer_port, err
                            );
                            METRICS.muxer_event_fails.inc();
                        });
                }
            } else {
                // The connection had previously asked to be removed from the listener map (by
                // returning an empty event set via `get_polled_evset()`), but now wants back in.
                self.add_listener(
                    fd,
                    EpollListener::Connection {
                        key,
                        evset: new_evset,
                    },
                )
                .unwrap_or_else(|err| {
                    self.kill_connection(key);
                    error!(
                        "vsock: error updating epoll listener for (lp={}, pp={}): {:?}",
                        key.local_port, key.peer_port, err
                    );
                    METRICS.muxer_event_fails.inc();
                });
            }
        }
    }
744
745    /// Check if any connections have timed out, and if so, schedule them for immediate
746    /// termination.
747    fn sweep_killq(&mut self) {
748        while let Some(key) = self.killq.pop() {
749            // Connections don't get removed from the kill queue when their kill timer is
750            // disarmed, since that would be a costly operation. This means we must check if
751            // the connection has indeed expired, prior to killing it.
752            let mut kill = false;
753            self.conn_map
754                .entry(key)
755                .and_modify(|conn| kill = conn.has_expired());
756            if kill {
757                self.kill_connection(key);
758            }
759        }
760
761        if self.killq.is_empty() && !self.killq.is_synced() {
762            self.killq = MuxerKillQ::from_conn_map(&self.conn_map);
763            METRICS.killq_resync.inc();
764            // If we've just re-created the kill queue, we can sweep it again; maybe there's
765            // more to kill.
766            self.sweep_killq();
767        }
768    }
769
770    /// Enqueue an RST packet into `self.rxq`.
771    ///
772    /// Enqueue errors aren't propagated up the call chain, since there is nothing we can do to
773    /// handle them. We do, however, log a warning, since not being able to enqueue an RST
774    /// packet means we have to drop it, which is not normal operation.
775    fn enq_rst(&mut self, local_port: u32, peer_port: u32) {
776        let pushed = self.rxq.push(MuxerRx::RstPkt {
777            local_port,
778            peer_port,
779        });
780        if !pushed {
781            warn!(
782                "vsock: muxer.rxq full; dropping RST packet for lp={}, pp={}",
783                local_port, peer_port
784            );
785        }
786    }
787}
788
789#[cfg(test)]
790mod tests {
791    use std::io::{Read, Write};
792    use std::ops::Drop;
793    use std::os::unix::net::{UnixListener, UnixStream};
794    use std::path::{Path, PathBuf};
795
796    use vmm_sys_util::tempfile::TempFile;
797
798    use super::super::super::csm::defs as csm_defs;
799    use super::*;
800    use crate::devices::virtio::vsock::device::{RXQ_INDEX, TXQ_INDEX};
801    use crate::devices::virtio::vsock::test_utils;
802    use crate::devices::virtio::vsock::test_utils::TestContext as VsockTestContext;
803
    // Guest (peer) context ID used by every test in this module.
    const PEER_CID: u64 = 3;
    // Peer buffer allocation advertised in every TX packet header (see `init_tx_pkt`).
    const PEER_BUF_ALLOC: u32 = 64 * 1024;
806
    // Test harness bundling a vsock device context, packet views, and the muxer under test.
    #[derive(Debug)]
    struct MuxerTestContext {
        // Keeps the guest memory and virtio queues alive for the duration of the test.
        _vsock_test_ctx: VsockTestContext,
        // Two views of the same in-memory packet. rx-view for writing, tx-view for reading
        rx_pkt: VsockPacketRx,
        tx_pkt: VsockPacketTx,
        // The muxer under test; its host Unix socket lives at a per-test temporary path.
        muxer: VsockMuxer,
    }
815
    impl Drop for MuxerTestContext {
        fn drop(&mut self) {
            // Clean up the host-side Unix socket file the muxer created for this test.
            std::fs::remove_file(self.muxer.host_sock_path.as_str()).unwrap();
        }
    }
821
822    // Create a TempFile with a given prefix and return it as a nice String
823    fn get_file(fprefix: &str) -> String {
824        let listener_path = TempFile::new_with_prefix(fprefix).unwrap();
825        listener_path
826            .as_path()
827            .as_os_str()
828            .to_str()
829            .unwrap()
830            .to_owned()
831    }
832
    impl MuxerTestContext {
        /// Set up a fresh vsock device context, parse one RX and one TX packet view off its
        /// virtio queues, and create a muxer bound to a temporary Unix socket path.
        fn new(name: &str) -> Self {
            let vsock_test_ctx = VsockTestContext::new();
            let mut handler_ctx = vsock_test_ctx.create_event_handler_context();
            // Parse the RX packet view from the head of the device's RX virtio queue.
            let mut rx_pkt = VsockPacketRx::new().unwrap();
            rx_pkt
                .parse(
                    &vsock_test_ctx.mem,
                    handler_ctx.device.queues[RXQ_INDEX].pop().unwrap().unwrap(),
                )
                .unwrap();
            // Parse the TX packet view from the head of the device's TX virtio queue.
            let mut tx_pkt = VsockPacketTx::default();
            tx_pkt
                .parse(
                    &vsock_test_ctx.mem,
                    handler_ctx.device.queues[TXQ_INDEX].pop().unwrap().unwrap(),
                )
                .unwrap();

            let muxer = VsockMuxer::new(PEER_CID, get_file(name)).unwrap();
            Self {
                _vsock_test_ctx: vsock_test_ctx,
                rx_pkt,
                tx_pkt,
                muxer,
            }
        }

        /// Fill in the TX packet header as a guest-originated packet with the given ports and
        /// operation, returning a mutable reference for further header tweaking.
        fn init_tx_pkt(&mut self, local_port: u32, peer_port: u32, op: u16) -> &mut VsockPacketTx {
            self.tx_pkt
                .hdr
                .set_type(uapi::VSOCK_TYPE_STREAM)
                .set_src_cid(PEER_CID)
                .set_dst_cid(uapi::VSOCK_HOST_CID)
                .set_src_port(peer_port)
                .set_dst_port(local_port)
                .set_op(op)
                .set_buf_alloc(PEER_BUF_ALLOC);
            &mut self.tx_pkt
        }

        /// Build a VSOCK_OP_RW (data) TX packet carrying `data` as its payload.
        fn init_data_tx_pkt(
            &mut self,
            local_port: u32,
            peer_port: u32,
            mut data: &[u8],
        ) -> &mut VsockPacketTx {
            assert!(data.len() <= self.tx_pkt.buf_size() as usize);
            let tx_pkt = self.init_tx_pkt(local_port, peer_port, uapi::VSOCK_OP_RW);
            tx_pkt.hdr.set_len(u32::try_from(data.len()).unwrap());

            // rx_pkt and tx_pkt are two views of the same in-memory packet (see the struct
            // field comment), so writing the payload through the rx view makes it readable
            // through the tx view.
            let data_len = data.len().try_into().unwrap(); // store in tmp var to make borrow checker happy.
            self.rx_pkt
                .read_at_offset_from(&mut data, 0, data_len)
                .unwrap();
            &mut self.tx_pkt
        }

        /// Deliver the TX packet to the muxer (guest -> host direction).
        fn send(&mut self) {
            self.muxer.send_pkt(&self.tx_pkt).unwrap();
        }

        /// Have the muxer fill in the RX packet (host -> guest direction).
        fn recv(&mut self) {
            self.muxer.recv_pkt(&mut self.rx_pkt).unwrap();
        }

        /// Fake an EPOLLIN notification on the muxer's nested epoll FD.
        fn notify_muxer(&mut self) {
            self.muxer.notify(EventSet::IN);
        }

        /// Count the muxer's epoll listeners, returning
        /// `(local_stream_listener_count, connection_listener_count)`.
        fn count_epoll_listeners(&self) -> (usize, usize) {
            let mut local_lsn_count = 0usize;
            let mut conn_lsn_count = 0usize;
            for key in self.muxer.listener_map.values() {
                match key {
                    EpollListener::LocalStream(_) => local_lsn_count += 1,
                    EpollListener::Connection { .. } => conn_lsn_count += 1,
                    _ => (),
                };
            }
            (local_lsn_count, conn_lsn_count)
        }

        /// Create a host-side Unix listener at the per-port path derived from the muxer's
        /// host socket path (i.e. `<host_sock_path>_<port>`).
        fn create_local_listener(&self, port: u32) -> LocalListener {
            LocalListener::new(format!("{}_{}", self.muxer.host_sock_path, port))
        }

        /// Perform a full host-initiated ("local") connection handshake: connect to the
        /// muxer's Unix socket, issue a `CONNECT <port>` command, answer the resulting vsock
        /// connection request on behalf of the guest, and check the "OK <port>" ack.
        /// Returns the connected stream and the local port the muxer allocated for it.
        fn local_connect(&mut self, peer_port: u32) -> (UnixStream, u32) {
            let (init_local_lsn_count, init_conn_lsn_count) = self.count_epoll_listeners();

            let mut stream = UnixStream::connect(self.muxer.host_sock_path.clone()).unwrap();
            stream.set_nonblocking(true).unwrap();
            // The muxer would now get notified of a new connection having arrived at its Unix
            // socket, so it can accept it.
            self.notify_muxer();

            // Just after having accepted a new local connection, the muxer should've added a new
            // `LocalStream` listener to its `listener_map`.
            let (local_lsn_count, _) = self.count_epoll_listeners();
            assert_eq!(local_lsn_count, init_local_lsn_count + 1);

            let buf = format!("CONNECT {}\n", peer_port);
            stream.write_all(buf.as_bytes()).unwrap();
            // The muxer would now get notified that data is available for reading from the locally
            // initiated connection.
            self.notify_muxer();

            // Successfully reading and parsing the connection request should have removed the
            // LocalStream epoll listener and added a Connection epoll listener.
            let (local_lsn_count, conn_lsn_count) = self.count_epoll_listeners();
            assert_eq!(local_lsn_count, init_local_lsn_count);
            assert_eq!(conn_lsn_count, init_conn_lsn_count + 1);

            // A LocalInit connection should've been added to the muxer connection map.  A new
            // local port should also have been allocated for the new LocalInit connection.
            let local_port = self.muxer.local_port_last;
            let key = ConnMapKey {
                local_port,
                peer_port,
            };
            assert!(self.muxer.conn_map.contains_key(&key));
            assert!(self.muxer.local_port_set.contains(&local_port));

            // A connection request for the peer should now be available from the muxer.
            assert!(self.muxer.has_pending_rx());
            self.recv();
            assert_eq!(self.rx_pkt.hdr.op(), uapi::VSOCK_OP_REQUEST);
            assert_eq!(self.rx_pkt.hdr.dst_port(), peer_port);
            assert_eq!(self.rx_pkt.hdr.src_port(), local_port);

            // Accept the connection request on behalf of the guest.
            self.init_tx_pkt(local_port, peer_port, uapi::VSOCK_OP_RESPONSE);
            self.send();

            // The muxer acks the established connection on the stream with "OK <port>\n".
            let mut buf = [0u8; 32];
            let len = stream.read(&mut buf[..]).unwrap();
            assert_eq!(&buf[..len], format!("OK {}\n", local_port).as_bytes());

            (stream, local_port)
        }
    }
973
    // A Unix domain socket listener bound to a filesystem path; the path is removed on drop.
    #[derive(Debug)]
    struct LocalListener {
        path: PathBuf,
        sock: UnixListener,
    }
979    impl LocalListener {
980        fn new<P: AsRef<Path> + Clone + Debug>(path: P) -> Self {
981            let path_buf = path.as_ref().to_path_buf();
982            let sock = UnixListener::bind(path).unwrap();
983            sock.set_nonblocking(true).unwrap();
984            Self {
985                path: path_buf,
986                sock,
987            }
988        }
989        fn accept(&mut self) -> UnixStream {
990            let (stream, _) = self.sock.accept().unwrap();
991            stream.set_nonblocking(true).unwrap();
992            stream
993        }
994    }
    impl Drop for LocalListener {
        fn drop(&mut self) {
            // Remove the socket file from disk once the listener goes away.
            std::fs::remove_file(&self.path).unwrap();
        }
    }
1000
    #[test]
    fn test_muxer_epoll_listener() {
        let ctx = MuxerTestContext::new("muxer_epoll_listener");
        // The muxer exposes its nested epoll FD as its own pollable FD, and it only ever
        // asks the main VMM loop for EPOLLIN notifications.
        assert_eq!(ctx.muxer.as_raw_fd(), ctx.muxer.epoll.as_raw_fd());
        assert_eq!(ctx.muxer.get_polled_evset(), EventSet::IN);
    }
1007
    #[test]
    fn test_muxer_epoll_listener_regression() {
        let mut ctx = MuxerTestContext::new("muxer_epoll_listener");
        ctx.local_connect(1025);

        // Grab the single connection created by `local_connect()` above.
        let (_, conn) = ctx.muxer.conn_map.iter().next().unwrap();

        // At this point the connection is only interested in EPOLLIN.
        assert_eq!(conn.get_polled_evset(), EventSet::IN);

        assert_eq!(METRICS.conn_event_fails.count(), 0);

        let conn_eventfd = conn.as_raw_fd();

        // Deliver an event (EPOLLOUT) that the connection did not ask for; the muxer should
        // count it as a connection event failure rather than act on it.
        ctx.muxer.handle_event(conn_eventfd, EventSet::OUT);

        assert_eq!(METRICS.conn_event_fails.count(), 1);
    }
1025
1026    #[test]
1027    fn test_bad_peer_pkt() {
1028        const LOCAL_PORT: u32 = 1026;
1029        const PEER_PORT: u32 = 1025;
1030        const SOCK_DGRAM: u16 = 2;
1031
1032        let mut ctx = MuxerTestContext::new("bad_peer_pkt");
1033        let tx_pkt = ctx.init_tx_pkt(LOCAL_PORT, PEER_PORT, uapi::VSOCK_OP_REQUEST);
1034        tx_pkt.hdr.set_type(SOCK_DGRAM);
1035        ctx.send();
1036
1037        // The guest sent a SOCK_DGRAM packet. Per the vsock spec, we need to reply with an RST
1038        // packet, since vsock only supports stream sockets.
1039        assert!(ctx.muxer.has_pending_rx());
1040        ctx.recv();
1041        assert_eq!(ctx.rx_pkt.hdr.op(), uapi::VSOCK_OP_RST);
1042        assert_eq!(ctx.rx_pkt.hdr.src_cid(), uapi::VSOCK_HOST_CID);
1043        assert_eq!(ctx.rx_pkt.hdr.dst_cid(), PEER_CID);
1044        assert_eq!(ctx.rx_pkt.hdr.src_port(), LOCAL_PORT);
1045        assert_eq!(ctx.rx_pkt.hdr.dst_port(), PEER_PORT);
1046
1047        // Any orphan (i.e. without a connection), non-RST packet, should be replied to with an
1048        // RST.
1049        let bad_ops = [
1050            uapi::VSOCK_OP_RESPONSE,
1051            uapi::VSOCK_OP_CREDIT_REQUEST,
1052            uapi::VSOCK_OP_CREDIT_UPDATE,
1053            uapi::VSOCK_OP_SHUTDOWN,
1054            uapi::VSOCK_OP_RW,
1055        ];
1056        for op in bad_ops.iter() {
1057            ctx.init_tx_pkt(LOCAL_PORT, PEER_PORT, *op);
1058            ctx.send();
1059            assert!(ctx.muxer.has_pending_rx());
1060            ctx.recv();
1061            assert_eq!(ctx.rx_pkt.hdr.op(), uapi::VSOCK_OP_RST);
1062            assert_eq!(ctx.rx_pkt.hdr.src_port(), LOCAL_PORT);
1063            assert_eq!(ctx.rx_pkt.hdr.dst_port(), PEER_PORT);
1064        }
1065
1066        // Any packet addressed to anything other than VSOCK_VHOST_CID should get dropped.
1067        assert!(!ctx.muxer.has_pending_rx());
1068        let tx_pkt = ctx.init_tx_pkt(LOCAL_PORT, PEER_PORT, uapi::VSOCK_OP_REQUEST);
1069        tx_pkt.hdr.set_dst_cid(uapi::VSOCK_HOST_CID + 1);
1070        ctx.send();
1071        assert!(!ctx.muxer.has_pending_rx());
1072    }
1073
    #[test]
    fn test_peer_connection() {
        // Exercise guest-initiated (peer) connections: refusal when no host listener exists,
        // acceptance when one does, then data flow in both directions.
        const LOCAL_PORT: u32 = 1026;
        const PEER_PORT: u32 = 1025;

        let mut ctx = MuxerTestContext::new("peer_connection");

        // Test peer connection refused.
        // There is no host-side listener for LOCAL_PORT yet, so the request must be RST-ed.
        ctx.init_tx_pkt(LOCAL_PORT, PEER_PORT, uapi::VSOCK_OP_REQUEST);
        ctx.send();
        ctx.recv();
        assert_eq!(ctx.rx_pkt.hdr.op(), uapi::VSOCK_OP_RST);
        assert_eq!(ctx.rx_pkt.hdr.len(), 0);
        assert_eq!(ctx.rx_pkt.hdr.src_cid(), uapi::VSOCK_HOST_CID);
        assert_eq!(ctx.rx_pkt.hdr.dst_cid(), PEER_CID);
        assert_eq!(ctx.rx_pkt.hdr.src_port(), LOCAL_PORT);
        assert_eq!(ctx.rx_pkt.hdr.dst_port(), PEER_PORT);

        // Test peer connection accepted.
        let mut listener = ctx.create_local_listener(LOCAL_PORT);
        ctx.init_tx_pkt(LOCAL_PORT, PEER_PORT, uapi::VSOCK_OP_REQUEST);
        ctx.send();
        assert_eq!(ctx.muxer.conn_map.len(), 1);
        let mut stream = listener.accept();
        ctx.recv();
        assert_eq!(ctx.rx_pkt.hdr.op(), uapi::VSOCK_OP_RESPONSE);
        assert_eq!(ctx.rx_pkt.hdr.len(), 0);
        assert_eq!(ctx.rx_pkt.hdr.src_cid(), uapi::VSOCK_HOST_CID);
        assert_eq!(ctx.rx_pkt.hdr.dst_cid(), PEER_CID);
        assert_eq!(ctx.rx_pkt.hdr.src_port(), LOCAL_PORT);
        assert_eq!(ctx.rx_pkt.hdr.dst_port(), PEER_PORT);
        let key = ConnMapKey {
            local_port: LOCAL_PORT,
            peer_port: PEER_PORT,
        };
        assert!(ctx.muxer.conn_map.contains_key(&key));

        // Test guest -> host data flow.
        let data = [1, 2, 3, 4];
        ctx.init_data_tx_pkt(LOCAL_PORT, PEER_PORT, &data);
        ctx.send();
        let mut buf = vec![0; data.len()];
        stream.read_exact(buf.as_mut_slice()).unwrap();
        assert_eq!(buf.as_slice(), data);

        // Test host -> guest data flow.
        let data = [5u8, 6, 7, 8];
        stream.write_all(&data).unwrap();

        // When data is available on the local stream, an EPOLLIN event would normally be delivered
        // to the muxer's nested epoll FD. For testing only, we can fake that event notification
        // here.
        ctx.notify_muxer();
        // After being notified, the muxer should've figured out that RX data was available for one
        // of its connections, so it should now be reporting that it can fill in an RX packet.
        assert!(ctx.muxer.has_pending_rx());
        ctx.recv();
        assert_eq!(ctx.rx_pkt.hdr.op(), uapi::VSOCK_OP_RW);
        assert_eq!(ctx.rx_pkt.hdr.src_port(), LOCAL_PORT);
        assert_eq!(ctx.rx_pkt.hdr.dst_port(), PEER_PORT);

        // The payload is read back through the tx view of the shared packet buffer.
        let buf = test_utils::read_packet_data(&ctx.tx_pkt, 4);
        assert_eq!(&buf, &data);

        assert!(!ctx.muxer.has_pending_rx());
    }
1140
1141    #[test]
1142    fn test_local_connection() {
1143        // Test guest -> host data flow.
1144        let mut ctx = MuxerTestContext::new("local_connection");
1145        let peer_port = 1025;
1146        let (mut stream, local_port) = ctx.local_connect(peer_port);
1147
1148        let data = [1, 2, 3, 4];
1149        ctx.init_data_tx_pkt(local_port, peer_port, &data);
1150        ctx.send();
1151
1152        let mut buf = vec![0u8; data.len()];
1153        stream.read_exact(buf.as_mut_slice()).unwrap();
1154        assert_eq!(buf.as_slice(), &data);
1155
1156        // Test host -> guest data flow.
1157        let mut ctx = MuxerTestContext::new("local_connection");
1158        let peer_port = 1025;
1159        let (mut stream, local_port) = ctx.local_connect(peer_port);
1160
1161        let data = [5, 6, 7, 8];
1162        stream.write_all(&data).unwrap();
1163        ctx.notify_muxer();
1164
1165        assert!(ctx.muxer.has_pending_rx());
1166        ctx.recv();
1167        assert_eq!(ctx.rx_pkt.hdr.op(), uapi::VSOCK_OP_RW);
1168        assert_eq!(ctx.rx_pkt.hdr.src_port(), local_port);
1169        assert_eq!(ctx.rx_pkt.hdr.dst_port(), peer_port);
1170
1171        let buf = test_utils::read_packet_data(&ctx.tx_pkt, 4);
1172        assert_eq!(&buf, &data);
1173    }
1174
1175    #[test]
1176    fn test_local_close() {
1177        let peer_port = 1025;
1178        let mut ctx = MuxerTestContext::new("local_close");
1179        let local_port;
1180        {
1181            let (_stream, local_port_) = ctx.local_connect(peer_port);
1182            local_port = local_port_;
1183        }
1184        // Local var `_stream` was now dropped, thus closing the local stream. After the muxer gets
1185        // notified via EPOLLIN, it should attempt to gracefully shutdown the connection, issuing a
1186        // VSOCK_OP_SHUTDOWN with both no-more-send and no-more-recv indications set.
1187        ctx.notify_muxer();
1188        assert!(ctx.muxer.has_pending_rx());
1189        ctx.recv();
1190        assert_eq!(ctx.rx_pkt.hdr.op(), uapi::VSOCK_OP_SHUTDOWN);
1191        assert_ne!(ctx.rx_pkt.hdr.flags() & uapi::VSOCK_FLAGS_SHUTDOWN_SEND, 0);
1192        assert_ne!(ctx.rx_pkt.hdr.flags() & uapi::VSOCK_FLAGS_SHUTDOWN_RCV, 0);
1193        assert_eq!(ctx.rx_pkt.hdr.src_port(), local_port);
1194        assert_eq!(ctx.rx_pkt.hdr.dst_port(), peer_port);
1195
1196        // The connection should get removed (and its local port freed), after the peer replies
1197        // with an RST.
1198        ctx.init_tx_pkt(local_port, peer_port, uapi::VSOCK_OP_RST);
1199        ctx.send();
1200        let key = ConnMapKey {
1201            local_port,
1202            peer_port,
1203        };
1204        assert!(!ctx.muxer.conn_map.contains_key(&key));
1205        assert!(!ctx.muxer.local_port_set.contains(&local_port));
1206    }
1207
    #[test]
    fn test_peer_close() {
        // A guest-initiated connection must be torn down when the guest sends a full
        // shutdown, and the host-side Unix stream must be closed as a result.
        let peer_port = 1025;
        let local_port = 1026;
        let mut ctx = MuxerTestContext::new("peer_close");

        // Establish a guest-initiated connection, accepted via a host-side listener.
        let mut sock = ctx.create_local_listener(local_port);
        ctx.init_tx_pkt(local_port, peer_port, uapi::VSOCK_OP_REQUEST);
        ctx.send();
        let mut stream = sock.accept();

        assert!(ctx.muxer.has_pending_rx());
        ctx.recv();
        assert_eq!(ctx.rx_pkt.hdr.op(), uapi::VSOCK_OP_RESPONSE);
        assert_eq!(ctx.rx_pkt.hdr.src_port(), local_port);
        assert_eq!(ctx.rx_pkt.hdr.dst_port(), peer_port);
        let key = ConnMapKey {
            local_port,
            peer_port,
        };
        assert!(ctx.muxer.conn_map.contains_key(&key));

        // Emulate a full shutdown from the peer (no-more-send + no-more-recv).
        let tx_pkt = ctx.init_tx_pkt(local_port, peer_port, uapi::VSOCK_OP_SHUTDOWN);
        tx_pkt.hdr.set_flag(uapi::VSOCK_FLAGS_SHUTDOWN_SEND);
        tx_pkt.hdr.set_flag(uapi::VSOCK_FLAGS_SHUTDOWN_RCV);
        ctx.send();

        // Now, the muxer should remove the connection from its map, and reply with an RST.
        assert!(ctx.muxer.has_pending_rx());
        ctx.recv();
        assert_eq!(ctx.rx_pkt.hdr.op(), uapi::VSOCK_OP_RST);
        assert_eq!(ctx.rx_pkt.hdr.src_port(), local_port);
        assert_eq!(ctx.rx_pkt.hdr.dst_port(), peer_port);
        let key = ConnMapKey {
            local_port,
            peer_port,
        };
        assert!(!ctx.muxer.conn_map.contains_key(&key));

        // The muxer should also drop / close the local Unix socket for this connection.
        // A 0-byte read means the remote end (the muxer) closed the stream.
        let mut buf = vec![0u8; 16];
        assert_eq!(stream.read(buf.as_mut_slice()).unwrap(), 0);
    }
1252
    #[test]
    fn test_muxer_rxq() {
        // Exercise the muxer RX queue's synced/desynced behavior around its fixed capacity.
        let mut ctx = MuxerTestContext::new("muxer_rxq");
        let local_port = 1026;
        let peer_port_first = 1025;
        let mut listener = ctx.create_local_listener(local_port);
        let mut streams: Vec<UnixStream> = Vec::new();

        // Fill the RX queue to capacity with connection responses.
        for peer_port in peer_port_first..peer_port_first + defs::MUXER_RXQ_SIZE {
            ctx.init_tx_pkt(local_port, peer_port, uapi::VSOCK_OP_REQUEST);
            ctx.send();
            streams.push(listener.accept());
        }

        // The muxer RX queue should now be full (with connection responses), but still
        // synchronized.
        assert!(ctx.muxer.rxq.is_synced());

        // One more queued reply should desync the RX queue.
        ctx.init_tx_pkt(
            local_port,
            peer_port_first + defs::MUXER_RXQ_SIZE,
            uapi::VSOCK_OP_REQUEST,
        );
        ctx.send();
        assert!(!ctx.muxer.rxq.is_synced());

        // With an out-of-sync queue, an RST should evict any non-RST packet from the queue, and
        // take its place. We'll check that by making sure that the last packet popped from the
        // queue is an RST.
        ctx.init_tx_pkt(local_port + 1, peer_port_first, uapi::VSOCK_OP_REQUEST);
        ctx.send();

        for peer_port in peer_port_first..peer_port_first + defs::MUXER_RXQ_SIZE - 1 {
            ctx.recv();
            assert_eq!(ctx.rx_pkt.hdr.op(), uapi::VSOCK_OP_RESPONSE);
            // The response order should hold. The evicted response should have been the last
            // enqueued.
            assert_eq!(ctx.rx_pkt.hdr.dst_port(), peer_port);
        }
        // There should be one more packet in the queue: the RST.
        assert_eq!(ctx.muxer.rxq.len(), 1);
        ctx.recv();
        assert_eq!(ctx.rx_pkt.hdr.op(), uapi::VSOCK_OP_RST);

        // The queue should now be empty, but out-of-sync, so the muxer should report it has some
        // pending RX.
        assert!(ctx.muxer.rxq.is_empty());
        assert!(!ctx.muxer.rxq.is_synced());
        assert!(ctx.muxer.has_pending_rx());

        // The next recv should sync the queue back up. It should also yield one of the two
        // responses that are still left:
        // - the one that desynchronized the queue; and
        // - the one that got evicted by the RST.
        ctx.recv();
        assert!(ctx.muxer.rxq.is_synced());
        assert_eq!(ctx.rx_pkt.hdr.op(), uapi::VSOCK_OP_RESPONSE);

        assert!(ctx.muxer.has_pending_rx());
        ctx.recv();
        assert_eq!(ctx.rx_pkt.hdr.op(), uapi::VSOCK_OP_RESPONSE);
    }
1316
    #[test]
    fn test_muxer_killq() {
        // Exercise kill-queue overflow, resync, and the sweep that terminates expired
        // connections (checked via the muxer metrics).
        let mut ctx = MuxerTestContext::new("muxer_killq");
        let local_port = 1026;
        let peer_port_first = 1025;
        let peer_port_last = peer_port_first + defs::MUXER_KILLQ_SIZE;
        let mut listener = ctx.create_local_listener(local_port);

        // Save metrics relevant for this test.
        let conns_added = METRICS.conns_added.count();
        let conns_killed = METRICS.conns_killed.count();
        let conns_removed = METRICS.conns_removed.count();
        let killq_resync = METRICS.killq_resync.count();

        // Establish MUXER_KILLQ_SIZE + 1 connections, dropping the host end of each right
        // away so that every one of them gets scheduled for termination (SHUTDOWN).
        for peer_port in peer_port_first..=peer_port_last {
            ctx.init_tx_pkt(local_port, peer_port, uapi::VSOCK_OP_REQUEST);
            ctx.send();
            ctx.notify_muxer();
            ctx.recv();
            assert_eq!(ctx.rx_pkt.hdr.op(), uapi::VSOCK_OP_RESPONSE);
            assert_eq!(ctx.rx_pkt.hdr.src_port(), local_port);
            assert_eq!(ctx.rx_pkt.hdr.dst_port(), peer_port);
            {
                // Accept and immediately drop the host-side stream.
                let _stream = listener.accept();
            }
            ctx.notify_muxer();
            ctx.recv();
            assert_eq!(ctx.rx_pkt.hdr.op(), uapi::VSOCK_OP_SHUTDOWN);
            assert_eq!(ctx.rx_pkt.hdr.src_port(), local_port);
            assert_eq!(ctx.rx_pkt.hdr.dst_port(), peer_port);
            // The kill queue should be synchronized, up until the `defs::MUXER_KILLQ_SIZE`th
            // connection we schedule for termination.
            assert_eq!(
                ctx.muxer.killq.is_synced(),
                peer_port < peer_port_first + defs::MUXER_KILLQ_SIZE
            );
        }

        assert!(!ctx.muxer.killq.is_synced());
        assert!(!ctx.muxer.has_pending_rx());

        // Wait for the kill timers to expire.
        std::thread::sleep(std::time::Duration::from_millis(
            csm_defs::CONN_SHUTDOWN_TIMEOUT_MS,
        ));

        // Trigger a kill queue sweep, by requesting a new connection.
        ctx.init_tx_pkt(local_port, peer_port_last + 1, uapi::VSOCK_OP_REQUEST);
        ctx.send();

        // Check that MUXER_KILLQ_SIZE + 2 connections were added
        // We count +2, because there are two extra connections being
        // done outside of the loop.
        assert_eq!(
            METRICS.conns_added.count(),
            conns_added + u64::from(defs::MUXER_KILLQ_SIZE) + 2
        );
        // Check that MUXER_KILLQ_SIZE connections were killed
        assert_eq!(
            METRICS.conns_killed.count(),
            conns_killed + u64::from(defs::MUXER_KILLQ_SIZE)
        );
        // No connections should be removed at this point.
        assert_eq!(METRICS.conns_removed.count(), conns_removed);

        assert_eq!(METRICS.killq_resync.count(), killq_resync + 1);
        // After sweeping the kill queue, it should now be synced (assuming the RX queue is larger
        // than the kill queue, since an RST packet will be queued for each killed connection).
        assert!(ctx.muxer.killq.is_synced());
        assert!(ctx.muxer.has_pending_rx());
        // There should be `defs::MUXER_KILLQ_SIZE` RSTs in the RX queue, from terminating the
        // dying connections in the recent killq sweep.
        for _p in peer_port_first..peer_port_last {
            ctx.recv();
            assert_eq!(ctx.rx_pkt.hdr.op(), uapi::VSOCK_OP_RST);
            assert_eq!(ctx.rx_pkt.hdr.src_port(), local_port);
        }

        // The connections should have been removed here.
        assert_eq!(
            METRICS.conns_removed.count(),
            conns_removed + u64::from(defs::MUXER_KILLQ_SIZE)
        );

        // There should be one more packet in the RX queue: the connection response to our
        // request that triggered the kill queue sweep.
        ctx.recv();
        assert_eq!(ctx.rx_pkt.hdr.op(), uapi::VSOCK_OP_RESPONSE);
        assert_eq!(ctx.rx_pkt.hdr.dst_port(), peer_port_last + 1);

        assert!(!ctx.muxer.has_pending_rx());
    }
1409
1410    #[test]
1411    fn test_regression_handshake() {
1412        // Address one of the issues found while fixing the following issue:
1413        // https://github.com/firecracker-microvm/firecracker/issues/1751
1414        // This test checks that the handshake message is not accounted for
1415        let mut ctx = MuxerTestContext::new("regression_handshake");
1416        let peer_port = 1025;
1417
1418        // Create a local connection.
1419        let (_, local_port) = ctx.local_connect(peer_port);
1420
1421        // Get the connection from the connection map.
1422        let key = ConnMapKey {
1423            local_port,
1424            peer_port,
1425        };
1426        let conn = ctx.muxer.conn_map.get_mut(&key).unwrap();
1427
1428        // Check that fwd_cnt is 0 - "OK ..." was not accounted for.
1429        assert_eq!(conn.fwd_cnt().0, 0);
1430    }
1431
    #[test]
    fn test_regression_rxq_pop() {
        // Address one of the issues found while fixing the following issue:
        // https://github.com/firecracker-microvm/firecracker/issues/1751
        // This test checks that a connection is not popped out of the muxer
        // rxq when multiple flags are set
        let mut ctx = MuxerTestContext::new("regression_rxq_pop");
        let peer_port = 1025;
        let (mut stream, local_port) = ctx.local_connect(peer_port);

        // Send some data.
        let data = [5u8, 6, 7, 8];
        stream.write_all(&data).unwrap();
        ctx.notify_muxer();

        // Get the connection from the connection map.
        let key = ConnMapKey {
            local_port,
            peer_port,
        };
        let conn = ctx.muxer.conn_map.get_mut(&key).unwrap();

        // Forcefully insert another flag, so the connection has two pending RX indications
        // (data + credit update).
        conn.insert_credit_update();

        // Call recv twice in order to check that the connection is still
        // in the rxq.
        assert!(ctx.muxer.has_pending_rx());
        ctx.recv();
        assert!(ctx.muxer.has_pending_rx());
        ctx.recv();

        // Since initially the connection had two flags set, now there should
        // not be any pending RX in the muxer.
        assert!(!ctx.muxer.has_pending_rx());
    }
1468
1469    #[test]
1470    fn test_vsock_basic_metrics() {
1471        // Save the metrics values that we need tested.
1472        let mut tx_packets_count = METRICS.tx_packets_count.count();
1473        let mut rx_packets_count = METRICS.rx_packets_count.count();
1474
1475        let tx_bytes_count = METRICS.tx_bytes_count.count();
1476        let rx_bytes_count = METRICS.rx_bytes_count.count();
1477
1478        let conns_added = METRICS.conns_added.count();
1479        let conns_removed = METRICS.conns_removed.count();
1480
1481        // Create a basic connection.
1482        let mut ctx = MuxerTestContext::new("vsock_basic_metrics");
1483        let peer_port = 1025;
1484        let (mut stream, local_port) = ctx.local_connect(peer_port);
1485
1486        // Once the handshake is done, we check that the TX bytes count has
1487        // not been increased.
1488        assert_eq!(METRICS.tx_bytes_count.count(), tx_bytes_count);
1489
1490        // Check that one packet was sent through the handshake.
1491        assert_eq!(METRICS.tx_packets_count.count(), tx_packets_count + 1);
1492        tx_packets_count = METRICS.tx_packets_count.count();
1493
1494        // Check that one packet was received through the handshake.
1495        assert_eq!(METRICS.rx_packets_count.count(), rx_packets_count + 1);
1496        rx_packets_count = METRICS.rx_packets_count.count();
1497
1498        // Check that a new connection was added.
1499        assert_eq!(METRICS.conns_added.count(), conns_added + 1);
1500
1501        // Send some data from guest to host.
1502        let data = [1, 2, 3, 4];
1503        ctx.init_data_tx_pkt(local_port, peer_port, &data);
1504        ctx.send();
1505
1506        // Check that tx_bytes was incremented.
1507        assert_eq!(
1508            METRICS.tx_bytes_count.count(),
1509            tx_bytes_count + data.len() as u64
1510        );
1511
1512        // Check that one packet was accounted for.
1513        assert_eq!(METRICS.tx_packets_count.count(), tx_packets_count + 1);
1514
1515        // Send some data from the host to the guest.
1516        let data = [1, 2, 3, 4, 5, 6];
1517        stream.write_all(&data).unwrap();
1518        ctx.notify_muxer();
1519        ctx.recv();
1520
1521        // Check that a packet was received.
1522        assert_eq!(METRICS.rx_packets_count.count(), rx_packets_count + 1);
1523
1524        // Check that the 6 bytes have been received.
1525        assert_eq!(
1526            METRICS.rx_bytes_count.count(),
1527            rx_bytes_count + data.len() as u64
1528        );
1529
1530        // Send a connection reset.
1531        ctx.init_tx_pkt(local_port, peer_port, uapi::VSOCK_OP_RST);
1532        ctx.send();
1533
1534        // Check that the connection was removed.
1535        assert_eq!(METRICS.conns_removed.count(), conns_removed + 1);
1536    }
1537}