vmm/devices/virtio/mem/request.rs

1// Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use vm_memory::{Address, ByteValued, GuestAddress};
5
6use crate::devices::virtio::generated::virtio_mem;
7
8#[derive(Clone, Copy, Debug, PartialEq, Eq)]
9pub struct RequestedRange {
10    pub(crate) addr: GuestAddress,
11    pub(crate) nb_blocks: usize,
12}
13
14#[derive(Clone, Copy, Debug, PartialEq, Eq)]
15pub(crate) enum Request {
16    Plug(RequestedRange),
17    Unplug(RequestedRange),
18    UnplugAll,
19    State(RequestedRange),
20    Unsupported(u32),
21}
22
23// SAFETY: this is safe, trust me bro
24unsafe impl ByteValued for virtio_mem::virtio_mem_req {}
25
26impl From<virtio_mem::virtio_mem_req> for Request {
27    fn from(req: virtio_mem::virtio_mem_req) -> Self {
28        match req.type_.into() {
29            // SAFETY: union type is checked in the match
30            virtio_mem::VIRTIO_MEM_REQ_PLUG => unsafe {
31                Request::Plug(RequestedRange {
32                    addr: GuestAddress(req.u.plug.addr),
33                    nb_blocks: req.u.plug.nb_blocks.into(),
34                })
35            },
36            // SAFETY: union type is checked in the match
37            virtio_mem::VIRTIO_MEM_REQ_UNPLUG => unsafe {
38                Request::Unplug(RequestedRange {
39                    addr: GuestAddress(req.u.unplug.addr),
40                    nb_blocks: req.u.unplug.nb_blocks.into(),
41                })
42            },
43            virtio_mem::VIRTIO_MEM_REQ_UNPLUG_ALL => Request::UnplugAll,
44            // SAFETY: union type is checked in the match
45            virtio_mem::VIRTIO_MEM_REQ_STATE => unsafe {
46                Request::State(RequestedRange {
47                    addr: GuestAddress(req.u.state.addr),
48                    nb_blocks: req.u.state.nb_blocks.into(),
49                })
50            },
51            t => Request::Unsupported(t),
52        }
53    }
54}
55
56#[repr(u16)]
57#[derive(Debug, Clone, Copy, Eq, PartialEq)]
58#[allow(clippy::cast_possible_truncation)]
59pub enum ResponseType {
60    Ack = virtio_mem::VIRTIO_MEM_RESP_ACK as u16,
61    Nack = virtio_mem::VIRTIO_MEM_RESP_NACK as u16,
62    Busy = virtio_mem::VIRTIO_MEM_RESP_BUSY as u16,
63    Error = virtio_mem::VIRTIO_MEM_RESP_ERROR as u16,
64}
65
66#[repr(u16)]
67#[derive(Debug, Clone, Copy, Eq, PartialEq)]
68#[allow(clippy::cast_possible_truncation)]
69pub enum BlockRangeState {
70    Plugged = virtio_mem::VIRTIO_MEM_STATE_PLUGGED as u16,
71    Unplugged = virtio_mem::VIRTIO_MEM_STATE_UNPLUGGED as u16,
72    Mixed = virtio_mem::VIRTIO_MEM_STATE_MIXED as u16,
73}
74
75#[derive(Debug, Clone, Eq, PartialEq)]
76pub struct Response {
77    pub resp_type: ResponseType,
78    // Only for State requests
79    pub state: Option<BlockRangeState>,
80}
81
82impl Response {
83    pub(crate) fn error() -> Self {
84        Response {
85            resp_type: ResponseType::Error,
86            state: None,
87        }
88    }
89
90    pub(crate) fn ack() -> Self {
91        Response {
92            resp_type: ResponseType::Ack,
93            state: None,
94        }
95    }
96
97    pub(crate) fn ack_with_state(state: BlockRangeState) -> Self {
98        Response {
99            resp_type: ResponseType::Ack,
100            state: Some(state),
101        }
102    }
103
104    pub(crate) fn is_ack(&self) -> bool {
105        self.resp_type == ResponseType::Ack
106    }
107
108    pub(crate) fn is_error(&self) -> bool {
109        self.resp_type == ResponseType::Error
110    }
111}
112
113// SAFETY: Plain data structures
114unsafe impl ByteValued for virtio_mem::virtio_mem_resp {}
115
116impl From<Response> for virtio_mem::virtio_mem_resp {
117    fn from(resp: Response) -> Self {
118        let mut out = virtio_mem::virtio_mem_resp {
119            type_: resp.resp_type as u16,
120            ..Default::default()
121        };
122        if let Some(state) = resp.state {
123            out.u.state.state = state as u16;
124        }
125        out
126    }
127}
128
/// Test-only reverse conversions: encode a [`Request`] into the raw request layout, and
/// decode a raw response back into a [`Response`].
#[cfg(test)]
mod test_util {
    use super::*;

    impl From<Request> for virtio_mem::virtio_mem_req {
        /// Encodes a [`Request`] into the raw request layout.
        ///
        /// # Panics
        ///
        /// Panics if a block count or an `Unsupported` type value does not fit its raw field.
        fn from(req: Request) -> virtio_mem::virtio_mem_req {
            match req {
                Request::Plug(r) => virtio_mem::virtio_mem_req {
                    type_: virtio_mem::VIRTIO_MEM_REQ_PLUG.try_into().unwrap(),
                    u: virtio_mem::virtio_mem_req__bindgen_ty_1 {
                        plug: virtio_mem::virtio_mem_req_plug {
                            addr: r.addr.raw_value(),
                            nb_blocks: r.nb_blocks.try_into().unwrap(),
                            ..Default::default()
                        },
                    },
                    ..Default::default()
                },
                Request::Unplug(r) => virtio_mem::virtio_mem_req {
                    type_: virtio_mem::VIRTIO_MEM_REQ_UNPLUG.try_into().unwrap(),
                    u: virtio_mem::virtio_mem_req__bindgen_ty_1 {
                        unplug: virtio_mem::virtio_mem_req_unplug {
                            addr: r.addr.raw_value(),
                            nb_blocks: r.nb_blocks.try_into().unwrap(),
                            ..Default::default()
                        },
                    },
                    ..Default::default()
                },
                Request::UnplugAll => virtio_mem::virtio_mem_req {
                    type_: virtio_mem::VIRTIO_MEM_REQ_UNPLUG_ALL.try_into().unwrap(),
                    ..Default::default()
                },
                Request::State(r) => virtio_mem::virtio_mem_req {
                    type_: virtio_mem::VIRTIO_MEM_REQ_STATE.try_into().unwrap(),
                    u: virtio_mem::virtio_mem_req__bindgen_ty_1 {
                        state: virtio_mem::virtio_mem_req_state {
                            addr: r.addr.raw_value(),
                            nb_blocks: r.nb_blocks.try_into().unwrap(),
                            ..Default::default()
                        },
                    },
                    ..Default::default()
                },
                Request::Unsupported(t) => virtio_mem::virtio_mem_req {
                    type_: t.try_into().unwrap(),
                    ..Default::default()
                },
            }
        }
    }

    impl From<virtio_mem::virtio_mem_resp> for Response {
        /// Decodes a raw response.
        ///
        /// `state` is always `Some`: the raw response does not say whether the payload is
        /// present, so callers should ignore it unless the request was STATE.
        ///
        /// # Panics
        ///
        /// Panics if `type_` or the state payload holds a value that is not one of the known
        /// `VIRTIO_MEM_RESP_*` / `VIRTIO_MEM_STATE_*` constants.
        fn from(resp: virtio_mem::virtio_mem_resp) -> Self {
            let resp_type = match resp.type_.into() {
                virtio_mem::VIRTIO_MEM_RESP_ACK => ResponseType::Ack,
                virtio_mem::VIRTIO_MEM_RESP_NACK => ResponseType::Nack,
                virtio_mem::VIRTIO_MEM_RESP_BUSY => ResponseType::Busy,
                virtio_mem::VIRTIO_MEM_RESP_ERROR => ResponseType::Error,
                t => panic!("Invalid response type: {:?}", t),
            };
            // SAFETY: test code only. The union's `state` member is plain integer data, so
            // reading it is always defined. A response that never set a state keeps its
            // default value 0, which decodes as PLUGGED.
            let raw_state = unsafe { resp.u.state.state };
            let state = match raw_state.into() {
                virtio_mem::VIRTIO_MEM_STATE_PLUGGED => BlockRangeState::Plugged,
                virtio_mem::VIRTIO_MEM_STATE_UNPLUGGED => BlockRangeState::Unplugged,
                virtio_mem::VIRTIO_MEM_STATE_MIXED => BlockRangeState::Mixed,
                t => panic!("Invalid state: {:?}", t),
            };
            Response {
                resp_type,
                state: Some(state),
            }
        }
    }
}