// vmm/devices/virtio/balloon/device.rs

// Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

4use std::ops::Deref;
5use std::sync::Arc;
6use std::time::Duration;
7
8use log::{error, info, warn};
9use serde::{Deserialize, Serialize};
10use timerfd::{ClockId, SetTimeFlags, TimerFd, TimerState};
11use vmm_sys_util::eventfd::EventFd;
12
13use super::super::ActivateError;
14use super::super::device::{DeviceState, VirtioDevice};
15use super::super::queue::Queue;
16use super::metrics::METRICS;
17use super::util::compact_page_frame_numbers;
18use super::{
19    BALLOON_DEV_ID, BALLOON_MIN_NUM_QUEUES, BALLOON_QUEUE_SIZE, DEFLATE_INDEX, FREE_PAGE_HINT_DONE,
20    FREE_PAGE_HINT_STOP, INFLATE_INDEX, MAX_PAGE_COMPACT_BUFFER, MAX_PAGES_IN_DESC,
21    MIB_TO_4K_PAGES, STATS_INDEX, VIRTIO_BALLOON_F_DEFLATE_ON_OOM,
22    VIRTIO_BALLOON_F_FREE_PAGE_HINTING, VIRTIO_BALLOON_F_FREE_PAGE_REPORTING,
23    VIRTIO_BALLOON_F_STATS_VQ, VIRTIO_BALLOON_PFN_SHIFT, VIRTIO_BALLOON_S_ALLOC_STALL,
24    VIRTIO_BALLOON_S_ASYNC_RECLAIM, VIRTIO_BALLOON_S_ASYNC_SCAN, VIRTIO_BALLOON_S_AVAIL,
25    VIRTIO_BALLOON_S_CACHES, VIRTIO_BALLOON_S_DIRECT_RECLAIM, VIRTIO_BALLOON_S_DIRECT_SCAN,
26    VIRTIO_BALLOON_S_HTLB_PGALLOC, VIRTIO_BALLOON_S_HTLB_PGFAIL, VIRTIO_BALLOON_S_MAJFLT,
27    VIRTIO_BALLOON_S_MEMFREE, VIRTIO_BALLOON_S_MEMTOT, VIRTIO_BALLOON_S_MINFLT,
28    VIRTIO_BALLOON_S_OOM_KILL, VIRTIO_BALLOON_S_SWAP_IN, VIRTIO_BALLOON_S_SWAP_OUT,
29};
30use crate::devices::virtio::balloon::BalloonError;
31use crate::devices::virtio::device::ActiveState;
32use crate::devices::virtio::generated::virtio_config::VIRTIO_F_VERSION_1;
33use crate::devices::virtio::generated::virtio_ids::VIRTIO_ID_BALLOON;
34use crate::devices::virtio::queue::InvalidAvailIdx;
35use crate::devices::virtio::transport::{VirtioInterrupt, VirtioInterruptType};
36use crate::logger::{IncMetric, log_dev_preview_warning};
37use crate::utils::u64_to_usize;
38use crate::vstate::memory::{
39    Address, ByteValued, Bytes, GuestAddress, GuestMemoryExtension, GuestMemoryMmap,
40};
41use crate::{impl_device_type, mem_size_mib};
42
/// Size in bytes of one page frame number (a `u32`) inside an inflate descriptor.
const SIZE_OF_U32: usize = std::mem::size_of::<u32>();
44const SIZE_OF_STAT: usize = std::mem::size_of::<BalloonStat>();
45
46fn mib_to_pages(amount_mib: u32) -> Result<u32, BalloonError> {
47    amount_mib
48        .checked_mul(MIB_TO_4K_PAGES)
49        .ok_or(BalloonError::TooMuchMemoryRequested(
50            u32::MAX / MIB_TO_4K_PAGES,
51        ))
52}
53
54fn pages_to_mib(amount_pages: u32) -> u32 {
55    amount_pages / MIB_TO_4K_PAGES
56}
57
58#[repr(C)]
59#[derive(Clone, Copy, Debug, Default, PartialEq)]
60pub(crate) struct ConfigSpace {
61    pub num_pages: u32,
62    pub actual_pages: u32,
63    pub free_page_hint_cmd_id: u32,
64}
65
66// SAFETY: Safe because ConfigSpace only contains plain data.
67unsafe impl ByteValued for ConfigSpace {}
68
69/// Holds state of the free page hinting run
70#[derive(Copy, Clone, Debug, Default, Serialize, Deserialize)]
71pub(crate) struct HintingState {
72    /// The command requested by us. Set to STOP by default.
73    pub host_cmd: u32,
74    /// The last command supplied by guest.
75    pub last_cmd_id: u32,
76    /// The command supplied by guest.
77    pub guest_cmd: Option<u32>,
78    /// Whether or not to automatically ack on STOP.
79    pub acknowledge_on_finish: bool,
80}
81
/// Serde default for `StartHintingCmd::acknowledge_on_stop`: unless the caller
/// opts out, the end of a hinting run is acknowledged automatically.
const fn default_ack_on_stop() -> bool {
    true
}
86
87/// Command recieved from the API to start a hinting run
88#[derive(Copy, Clone, Debug, Eq, PartialEq, Deserialize)]
89pub struct StartHintingCmd {
90    /// If we should automatically acknowledge end of the run after stop.
91    #[serde(default = "default_ack_on_stop")]
92    pub acknowledge_on_stop: bool,
93}
94
95impl Default for StartHintingCmd {
96    fn default() -> Self {
97        Self {
98            acknowledge_on_stop: true,
99        }
100    }
101}
102
103/// Returned to the API for get hinting status
104#[derive(Copy, Clone, Debug, Eq, PartialEq, Default, Serialize)]
105pub struct HintingStatus {
106    /// The command requested by us. Set to STOP by default.
107    pub host_cmd: u32,
108    /// The command supplied by guest.
109    pub guest_cmd: Option<u32>,
110}
111
112// This structure needs the `packed` attribute, otherwise Rust will assume
113// the size to be 16 bytes.
114#[derive(Copy, Clone, Debug, Default)]
115#[repr(C, packed)]
116struct BalloonStat {
117    pub tag: u16,
118    pub val: u64,
119}
120
121// SAFETY: Safe because BalloonStat only contains plain data.
122unsafe impl ByteValued for BalloonStat {}
123
124/// Holds configuration details for the balloon device.
125#[derive(Clone, Default, Debug, PartialEq, Eq, Serialize)]
126pub struct BalloonConfig {
127    /// Target size.
128    pub amount_mib: u32,
129    /// Whether or not to ask for pages back.
130    pub deflate_on_oom: bool,
131    /// Interval of time in seconds at which the balloon statistics are updated.
132    pub stats_polling_interval_s: u16,
133    /// Free page hinting enabled
134    #[serde(default)]
135    pub free_page_hinting: bool,
136    /// Free page reporting enabled
137    #[serde(default)]
138    pub free_page_reporting: bool,
139}
140
141/// BalloonStats holds statistics returned from the stats_queue.
142#[derive(Clone, Copy, Default, Debug, PartialEq, Eq, Serialize)]
143#[serde(deny_unknown_fields)]
144pub struct BalloonStats {
145    /// The target size of the balloon, in 4K pages.
146    pub target_pages: u32,
147    /// The number of 4K pages the device is currently holding.
148    pub actual_pages: u32,
149    /// The target size of the balloon, in MiB.
150    pub target_mib: u32,
151    /// The number of MiB the device is currently holding.
152    pub actual_mib: u32,
153    /// Amount of memory swapped in.
154    #[serde(skip_serializing_if = "Option::is_none")]
155    pub swap_in: Option<u64>,
156    /// Amount of memory swapped out.
157    #[serde(skip_serializing_if = "Option::is_none")]
158    pub swap_out: Option<u64>,
159    /// Number of major faults.
160    #[serde(skip_serializing_if = "Option::is_none")]
161    pub major_faults: Option<u64>,
162    /// Number of minor faults.
163    #[serde(skip_serializing_if = "Option::is_none")]
164    pub minor_faults: Option<u64>,
165    /// The amount of memory not being used for any
166    /// purpose (in bytes).
167    #[serde(skip_serializing_if = "Option::is_none")]
168    pub free_memory: Option<u64>,
169    /// Total amount of memory available (in bytes).
170    #[serde(skip_serializing_if = "Option::is_none")]
171    pub total_memory: Option<u64>,
172    /// An estimate of how much memory is available (in
173    /// bytes) for starting new applications, without pushing the system to swap.
174    #[serde(skip_serializing_if = "Option::is_none")]
175    pub available_memory: Option<u64>,
176    /// The amount of memory, in bytes, that can be
177    /// quickly reclaimed without additional I/O. Typically these pages are used for
178    /// caching files from disk.
179    #[serde(skip_serializing_if = "Option::is_none")]
180    pub disk_caches: Option<u64>,
181    /// The number of successful hugetlb page
182    /// allocations in the guest.
183    #[serde(skip_serializing_if = "Option::is_none")]
184    pub hugetlb_allocations: Option<u64>,
185    /// The number of failed hugetlb page allocations
186    /// in the guest.
187    #[serde(skip_serializing_if = "Option::is_none")]
188    pub hugetlb_failures: Option<u64>,
189    /// OOM killer invocations. since linux v6.12.
190    #[serde(skip_serializing_if = "Option::is_none")]
191    pub oom_kill: Option<u64>,
192    /// Stall count of memory allocatoin. since linux v6.12.
193    #[serde(skip_serializing_if = "Option::is_none")]
194    pub alloc_stall: Option<u64>,
195    /// Amount of memory scanned asynchronously. since linux v6.12.
196    #[serde(skip_serializing_if = "Option::is_none")]
197    pub async_scan: Option<u64>,
198    /// Amount of memory scanned directly. since linux v6.12.
199    #[serde(skip_serializing_if = "Option::is_none")]
200    pub direct_scan: Option<u64>,
201    /// Amount of memory reclaimed asynchronously. since linux v6.12.
202    #[serde(skip_serializing_if = "Option::is_none")]
203    pub async_reclaim: Option<u64>,
204    /// Amount of memory reclaimed directly. since linux v6.12.
205    #[serde(skip_serializing_if = "Option::is_none")]
206    pub direct_reclaim: Option<u64>,
207}
208
209impl BalloonStats {
210    fn update_with_stat(&mut self, stat: &BalloonStat) -> Result<(), BalloonError> {
211        let val = Some(stat.val);
212        match stat.tag {
213            VIRTIO_BALLOON_S_SWAP_IN => self.swap_in = val,
214            VIRTIO_BALLOON_S_SWAP_OUT => self.swap_out = val,
215            VIRTIO_BALLOON_S_MAJFLT => self.major_faults = val,
216            VIRTIO_BALLOON_S_MINFLT => self.minor_faults = val,
217            VIRTIO_BALLOON_S_MEMFREE => self.free_memory = val,
218            VIRTIO_BALLOON_S_MEMTOT => self.total_memory = val,
219            VIRTIO_BALLOON_S_AVAIL => self.available_memory = val,
220            VIRTIO_BALLOON_S_CACHES => self.disk_caches = val,
221            VIRTIO_BALLOON_S_HTLB_PGALLOC => self.hugetlb_allocations = val,
222            VIRTIO_BALLOON_S_HTLB_PGFAIL => self.hugetlb_failures = val,
223            VIRTIO_BALLOON_S_OOM_KILL => self.oom_kill = val,
224            VIRTIO_BALLOON_S_ALLOC_STALL => self.alloc_stall = val,
225            VIRTIO_BALLOON_S_ASYNC_SCAN => self.async_scan = val,
226            VIRTIO_BALLOON_S_DIRECT_SCAN => self.direct_scan = val,
227            VIRTIO_BALLOON_S_ASYNC_RECLAIM => self.async_reclaim = val,
228            VIRTIO_BALLOON_S_DIRECT_RECLAIM => self.direct_reclaim = val,
229            _ => {
230                return Err(BalloonError::MalformedPayload);
231            }
232        }
233
234        Ok(())
235    }
236}
237
238/// Virtio balloon device.
239#[derive(Debug)]
240pub struct Balloon {
241    // Virtio fields.
242    pub(crate) avail_features: u64,
243    pub(crate) acked_features: u64,
244    pub(crate) config_space: ConfigSpace,
245    pub(crate) activate_evt: EventFd,
246
247    // Transport related fields.
248    pub(crate) queues: Vec<Queue>,
249    pub(crate) queue_evts: Vec<EventFd>,
250    pub(crate) device_state: DeviceState,
251
252    // Implementation specific fields.
253    pub(crate) stats_polling_interval_s: u16,
254    pub(crate) stats_timer: TimerFd,
255    // The index of the previous stats descriptor is saved because
256    // it is acknowledged after the stats queue is processed.
257    pub(crate) stats_desc_index: Option<u16>,
258    pub(crate) latest_stats: BalloonStats,
259    // A buffer used as pfn accumulator during descriptor processing.
260    pub(crate) pfn_buffer: [u32; MAX_PAGE_COMPACT_BUFFER],
261
262    // Holds state for free page hinting
263    pub(crate) hinting_state: HintingState,
264}
265
266impl Balloon {
267    /// Instantiate a new balloon device.
268    pub fn new(
269        amount_mib: u32,
270        deflate_on_oom: bool,
271        stats_polling_interval_s: u16,
272        free_page_hinting: bool,
273        free_page_reporting: bool,
274    ) -> Result<Balloon, BalloonError> {
275        let mut avail_features = 1u64 << VIRTIO_F_VERSION_1;
276
277        if deflate_on_oom {
278            avail_features |= 1u64 << VIRTIO_BALLOON_F_DEFLATE_ON_OOM;
279        };
280
281        // The VirtIO specification states that the statistics queue should
282        // not be present at all if the statistics are not enabled.
283        let mut queue_count = BALLOON_MIN_NUM_QUEUES;
284        if stats_polling_interval_s > 0 {
285            avail_features |= 1u64 << VIRTIO_BALLOON_F_STATS_VQ;
286            queue_count += 1;
287        }
288
289        if free_page_hinting {
290            log_dev_preview_warning("Free Page Hinting", None);
291            avail_features |= 1u64 << VIRTIO_BALLOON_F_FREE_PAGE_HINTING;
292            queue_count += 1;
293        }
294
295        if free_page_reporting {
296            avail_features |= 1u64 << VIRTIO_BALLOON_F_FREE_PAGE_REPORTING;
297            queue_count += 1;
298        }
299
300        let queues: Vec<Queue> = (0..queue_count)
301            .map(|_| Queue::new(BALLOON_QUEUE_SIZE))
302            .collect();
303        let queue_evts = (0..queue_count)
304            .map(|_| EventFd::new(libc::EFD_NONBLOCK).map_err(BalloonError::EventFd))
305            .collect::<Result<Vec<_>, _>>()?;
306
307        let stats_timer =
308            TimerFd::new_custom(ClockId::Monotonic, true, true).map_err(BalloonError::Timer)?;
309
310        Ok(Balloon {
311            avail_features,
312            acked_features: 0u64,
313            config_space: ConfigSpace {
314                num_pages: mib_to_pages(amount_mib)?,
315                actual_pages: 0,
316                free_page_hint_cmd_id: FREE_PAGE_HINT_STOP,
317            },
318            queue_evts,
319            queues,
320            device_state: DeviceState::Inactive,
321            activate_evt: EventFd::new(libc::EFD_NONBLOCK).map_err(BalloonError::EventFd)?,
322            stats_polling_interval_s,
323            stats_timer,
324            stats_desc_index: None,
325            latest_stats: BalloonStats::default(),
326            pfn_buffer: [0u32; MAX_PAGE_COMPACT_BUFFER],
327            hinting_state: Default::default(),
328        })
329    }
330
331    pub(crate) fn process_inflate_queue_event(&mut self) -> Result<(), BalloonError> {
332        self.queue_evts[INFLATE_INDEX]
333            .read()
334            .map_err(BalloonError::EventFd)?;
335        self.process_inflate()
336    }
337
338    pub(crate) fn process_deflate_queue_event(&mut self) -> Result<(), BalloonError> {
339        self.queue_evts[DEFLATE_INDEX]
340            .read()
341            .map_err(BalloonError::EventFd)?;
342        self.process_deflate_queue()
343    }
344
345    pub(crate) fn process_stats_queue_event(&mut self) -> Result<(), BalloonError> {
346        self.queue_evts[STATS_INDEX]
347            .read()
348            .map_err(BalloonError::EventFd)?;
349        self.process_stats_queue()
350    }
351
352    pub(crate) fn process_stats_timer_event(&mut self) -> Result<(), BalloonError> {
353        self.stats_timer.read();
354        self.trigger_stats_update()
355    }
356
357    pub(crate) fn process_free_page_hinting_queue_event(&mut self) -> Result<(), BalloonError> {
358        self.queue_evts[self.free_page_hinting_idx()]
359            .read()
360            .map_err(BalloonError::EventFd)?;
361        self.process_free_page_hinting_queue()
362    }
363
364    pub(crate) fn process_free_page_reporting_queue_event(&mut self) -> Result<(), BalloonError> {
365        self.queue_evts[self.free_page_reporting_idx()]
366            .read()
367            .map_err(BalloonError::EventFd)?;
368        self.process_free_page_reporting_queue()
369    }
370
371    pub(crate) fn process_inflate(&mut self) -> Result<(), BalloonError> {
372        // This is safe since we checked in the event handler that the device is activated.
373        let mem = &self
374            .device_state
375            .active_state()
376            .ok_or(BalloonError::DeviceNotActive)?
377            .mem;
378        METRICS.inflate_count.inc();
379
380        let queue = &mut self.queues[INFLATE_INDEX];
381        // The pfn buffer index used during descriptor processing.
382        let mut pfn_buffer_idx = 0;
383        let mut needs_interrupt = false;
384        let mut valid_descs_found = true;
385
386        // Loop until there are no more valid DescriptorChains.
387        while valid_descs_found {
388            valid_descs_found = false;
389            // Internal loop processes descriptors and acummulates the pfns in `pfn_buffer`.
390            // Breaks out when there is not enough space in `pfn_buffer` to completely process
391            // the next descriptor.
392            while let Some(head) = queue.pop()? {
393                let len = head.len as usize;
394                let max_len = MAX_PAGES_IN_DESC * SIZE_OF_U32;
395                valid_descs_found = true;
396
397                if !head.is_write_only() && len % SIZE_OF_U32 == 0 {
398                    // Check descriptor pfn count.
399                    if len > max_len {
400                        error!(
401                            "Inflate descriptor has bogus page count {} > {}, skipping.",
402                            len / SIZE_OF_U32,
403                            MAX_PAGES_IN_DESC
404                        );
405
406                        // Skip descriptor.
407                        continue;
408                    }
409                    // Break loop if `pfn_buffer` will be overrun by adding all pfns from current
410                    // desc.
411                    if MAX_PAGE_COMPACT_BUFFER - pfn_buffer_idx < len / SIZE_OF_U32 {
412                        queue.undo_pop();
413                        break;
414                    }
415
416                    // This is safe, `len` was validated above.
417                    for index in (0..len).step_by(SIZE_OF_U32) {
418                        let addr = head
419                            .addr
420                            .checked_add(index as u64)
421                            .ok_or(BalloonError::MalformedDescriptor)?;
422
423                        let page_frame_number = mem
424                            .read_obj::<u32>(addr)
425                            .map_err(|_| BalloonError::MalformedDescriptor)?;
426
427                        self.pfn_buffer[pfn_buffer_idx] = page_frame_number;
428                        pfn_buffer_idx += 1;
429                    }
430                }
431
432                // Acknowledge the receipt of the descriptor.
433                // 0 is number of bytes the device has written to memory.
434                queue.add_used(head.index, 0)?;
435                needs_interrupt = true;
436            }
437
438            // Compact pages into ranges.
439            let page_ranges = compact_page_frame_numbers(&mut self.pfn_buffer[..pfn_buffer_idx]);
440            pfn_buffer_idx = 0;
441
442            // Remove the page ranges.
443            for (page_frame_number, range_len) in page_ranges {
444                let guest_addr =
445                    GuestAddress(u64::from(page_frame_number) << VIRTIO_BALLOON_PFN_SHIFT);
446
447                if let Err(err) = mem.discard_range(
448                    guest_addr,
449                    usize::try_from(range_len).unwrap() << VIRTIO_BALLOON_PFN_SHIFT,
450                ) {
451                    error!("Error removing memory range: {:?}", err);
452                }
453            }
454        }
455        queue.advance_used_ring_idx();
456
457        if needs_interrupt {
458            self.signal_used_queue(INFLATE_INDEX)?;
459        }
460
461        Ok(())
462    }
463
464    pub(crate) fn process_deflate_queue(&mut self) -> Result<(), BalloonError> {
465        METRICS.deflate_count.inc();
466
467        let queue = &mut self.queues[DEFLATE_INDEX];
468        let mut needs_interrupt = false;
469
470        while let Some(head) = queue.pop()? {
471            queue.add_used(head.index, 0)?;
472            needs_interrupt = true;
473        }
474        queue.advance_used_ring_idx();
475
476        if needs_interrupt {
477            self.signal_used_queue(DEFLATE_INDEX)
478        } else {
479            Ok(())
480        }
481    }
482
483    pub(crate) fn process_stats_queue(&mut self) -> Result<(), BalloonError> {
484        // This is safe since we checked in the event handler that the device is activated.
485        let mem = &self.device_state.active_state().unwrap().mem;
486        METRICS.stats_updates_count.inc();
487
488        while let Some(head) = self.queues[STATS_INDEX].pop()? {
489            if let Some(prev_stats_desc) = self.stats_desc_index {
490                // We shouldn't ever have an extra buffer if the driver follows
491                // the protocol, but return it if we find one.
492                error!("balloon: driver is not compliant, more than one stats buffer received");
493                self.queues[STATS_INDEX].add_used(prev_stats_desc, 0)?;
494            }
495            for index in (0..head.len).step_by(SIZE_OF_STAT) {
496                // Read the address at position `index`. The only case
497                // in which this fails is if there is overflow,
498                // in which case this descriptor is malformed,
499                // so we ignore the rest of it.
500                let addr = head
501                    .addr
502                    .checked_add(u64::from(index))
503                    .ok_or(BalloonError::MalformedDescriptor)?;
504                let stat = mem
505                    .read_obj::<BalloonStat>(addr)
506                    .map_err(|_| BalloonError::MalformedDescriptor)?;
507                self.latest_stats.update_with_stat(&stat).map_err(|_| {
508                    METRICS.stats_update_fails.inc();
509                    BalloonError::MalformedPayload
510                })?;
511            }
512
513            self.stats_desc_index = Some(head.index);
514        }
515
516        Ok(())
517    }
518
519    pub(crate) fn process_free_page_hinting_queue(&mut self) -> Result<(), BalloonError> {
520        let mem = &self
521            .device_state
522            .active_state()
523            .ok_or(BalloonError::DeviceNotActive)?
524            .mem;
525
526        let idx = self.free_page_hinting_idx();
527        let queue = &mut self.queues[idx];
528        let host_cmd = self.hinting_state.host_cmd;
529        let mut needs_interrupt = false;
530        let mut complete = false;
531
532        while let Some(head) = queue.pop()? {
533            let head_index = head.index;
534
535            let mut last_desc = Some(head);
536            while let Some(desc) = last_desc {
537                last_desc = desc.next_descriptor();
538
539                // Updated cmd_ids are always of length 4
540                if desc.len == 4 {
541                    complete = false;
542
543                    let cmd = mem
544                        .read_obj::<u32>(desc.addr)
545                        .map_err(|_| BalloonError::MalformedDescriptor)?;
546                    self.hinting_state.guest_cmd = Some(cmd);
547                    if cmd == FREE_PAGE_HINT_STOP {
548                        complete = true;
549                    }
550
551                    // We don't expect this from the driver, but lets treat as a stop
552                    if cmd == FREE_PAGE_HINT_DONE {
553                        warn!("balloon hinting: Unexpected cmd from guest: {cmd}");
554                        complete = true;
555                    }
556
557                    continue;
558                }
559
560                // If we've requested done we have to discard any in-flight hints
561                if host_cmd == FREE_PAGE_HINT_DONE || host_cmd == FREE_PAGE_HINT_STOP {
562                    continue;
563                }
564
565                let Some(chain_cmd) = self.hinting_state.guest_cmd else {
566                    warn!("balloon hinting: received range with no command id.");
567                    continue;
568                };
569
570                if chain_cmd != host_cmd {
571                    info!("balloon hinting: Received chain from previous command ignoring.");
572                    continue;
573                }
574
575                METRICS.free_page_hint_count.inc();
576                if let Err(err) = mem.discard_range(desc.addr, desc.len as usize) {
577                    METRICS.free_page_hint_fails.inc();
578                    error!("balloon hinting: failed to remove range: {err:?}");
579                } else {
580                    METRICS.free_page_hint_freed.add(desc.len as u64);
581                }
582            }
583
584            queue.add_used(head.index, 0)?;
585            needs_interrupt = true;
586        }
587
588        queue.advance_used_ring_idx();
589
590        if needs_interrupt {
591            self.signal_used_queue(idx)?;
592        }
593
594        if complete && self.hinting_state.acknowledge_on_finish {
595            self.update_free_page_hint_cmd(FREE_PAGE_HINT_DONE);
596        }
597
598        Ok(())
599    }
600
601    pub(crate) fn process_free_page_reporting_queue(&mut self) -> Result<(), BalloonError> {
602        let mem = &self
603            .device_state
604            .active_state()
605            .ok_or(BalloonError::DeviceNotActive)?
606            .mem;
607
608        let idx = self.free_page_reporting_idx();
609        let queue = &mut self.queues[idx];
610        let mut needs_interrupt = false;
611
612        while let Some(head) = queue.pop()? {
613            let head_index = head.index;
614
615            let mut last_desc = Some(head);
616            while let Some(desc) = last_desc {
617                METRICS.free_page_report_count.inc();
618                if let Err(err) = mem.discard_range(desc.addr, desc.len as usize) {
619                    METRICS.free_page_report_fails.inc();
620                    error!("balloon: failed to remove range: {err:?}");
621                } else {
622                    METRICS.free_page_report_freed.add(desc.len as u64);
623                }
624                last_desc = desc.next_descriptor();
625            }
626
627            queue.add_used(head.index, 0)?;
628            needs_interrupt = true;
629        }
630
631        queue.advance_used_ring_idx();
632
633        if needs_interrupt {
634            self.signal_used_queue(idx)?;
635        }
636
637        Ok(())
638    }
639
640    pub(crate) fn signal_used_queue(&self, qidx: usize) -> Result<(), BalloonError> {
641        self.interrupt_trigger()
642            .trigger(VirtioInterruptType::Queue(
643                qidx.try_into()
644                    .unwrap_or_else(|_| panic!("balloon: invalid queue id: {qidx}")),
645            ))
646            .map_err(|err| {
647                METRICS.event_fails.inc();
648                BalloonError::InterruptError(err)
649            })
650    }
651
652    /// Process device virtio queue(s).
653    pub fn process_virtio_queues(&mut self) -> Result<(), InvalidAvailIdx> {
654        if let Err(BalloonError::InvalidAvailIdx(err)) = self.process_inflate() {
655            return Err(err);
656        }
657        if let Err(BalloonError::InvalidAvailIdx(err)) = self.process_deflate_queue() {
658            return Err(err);
659        }
660
661        if self.free_page_hinting()
662            && let Err(BalloonError::InvalidAvailIdx(err)) = self.process_free_page_hinting_queue()
663        {
664            return Err(err);
665        }
666
667        if self.free_page_reporting()
668            && let Err(BalloonError::InvalidAvailIdx(err)) =
669                self.process_free_page_reporting_queue()
670        {
671            return Err(err);
672        }
673
674        Ok(())
675    }
676
677    /// Provides the ID of this balloon device.
678    pub fn id(&self) -> &str {
679        BALLOON_DEV_ID
680    }
681
682    fn trigger_stats_update(&mut self) -> Result<(), BalloonError> {
683        // The communication is driven by the device by using the buffer
684        // and sending a used buffer notification
685        if let Some(index) = self.stats_desc_index.take() {
686            self.queues[STATS_INDEX].add_used(index, 0)?;
687            self.queues[STATS_INDEX].advance_used_ring_idx();
688            self.signal_used_queue(STATS_INDEX)
689        } else {
690            error!("Failed to update balloon stats, missing descriptor.");
691            Ok(())
692        }
693    }
694
695    /// Update the target size of the balloon.
696    pub fn update_size(&mut self, amount_mib: u32) -> Result<(), BalloonError> {
697        if self.is_activated() {
698            let mem = &self.device_state.active_state().unwrap().mem;
699            // The balloon cannot have a target size greater than the size of
700            // the guest memory.
701            if u64::from(amount_mib) > mem_size_mib(mem) {
702                return Err(BalloonError::TooMuchMemoryRequested(amount_mib));
703            }
704            self.config_space.num_pages = mib_to_pages(amount_mib)?;
705            self.interrupt_trigger()
706                .trigger(VirtioInterruptType::Config)
707                .map_err(BalloonError::InterruptError)
708        } else {
709            Err(BalloonError::DeviceNotActive)
710        }
711    }
712
713    pub fn free_page_hinting(&self) -> bool {
714        self.avail_features & (1u64 << VIRTIO_BALLOON_F_FREE_PAGE_HINTING) != 0
715    }
716
717    pub fn free_page_hinting_idx(&self) -> usize {
718        let mut idx = BALLOON_MIN_NUM_QUEUES;
719
720        if self.stats_polling_interval_s > 0 {
721            idx += 1;
722        }
723
724        idx
725    }
726
727    pub fn free_page_reporting(&self) -> bool {
728        self.avail_features & (1u64 << VIRTIO_BALLOON_F_FREE_PAGE_REPORTING) != 0
729    }
730
731    pub fn free_page_reporting_idx(&self) -> usize {
732        let mut idx = BALLOON_MIN_NUM_QUEUES;
733
734        if self.stats_polling_interval_s > 0 {
735            idx += 1;
736        }
737
738        if self.free_page_hinting() {
739            idx += 1;
740        }
741
742        idx
743    }
744
745    /// Update the statistics polling interval.
746    pub fn update_stats_polling_interval(&mut self, interval_s: u16) -> Result<(), BalloonError> {
747        if self.stats_polling_interval_s == interval_s {
748            return Ok(());
749        }
750
751        if self.stats_polling_interval_s == 0 || interval_s == 0 {
752            return Err(BalloonError::StatisticsStateChange);
753        }
754
755        self.trigger_stats_update()?;
756
757        self.stats_polling_interval_s = interval_s;
758        self.update_timer_state();
759        Ok(())
760    }
761
762    pub fn update_timer_state(&mut self) {
763        let timer_state = TimerState::Periodic {
764            current: Duration::from_secs(u64::from(self.stats_polling_interval_s)),
765            interval: Duration::from_secs(u64::from(self.stats_polling_interval_s)),
766        };
767        self.stats_timer
768            .set_state(timer_state, SetTimeFlags::Default);
769    }
770
771    /// Obtain the number of 4K pages the device is currently holding.
772    pub fn num_pages(&self) -> u32 {
773        self.config_space.num_pages
774    }
775
776    /// Obtain the size of 4K pages the device is currently holding in MIB.
777    pub fn size_mb(&self) -> u32 {
778        pages_to_mib(self.config_space.num_pages)
779    }
780
781    pub fn deflate_on_oom(&self) -> bool {
782        self.avail_features & (1u64 << VIRTIO_BALLOON_F_DEFLATE_ON_OOM) != 0
783    }
784
785    pub fn stats_polling_interval_s(&self) -> u16 {
786        self.stats_polling_interval_s
787    }
788
789    /// Retrieve latest stats for the balloon device.
790    pub fn latest_stats(&mut self) -> Result<BalloonStats, BalloonError> {
791        if self.stats_enabled() {
792            self.latest_stats.target_pages = self.config_space.num_pages;
793            self.latest_stats.actual_pages = self.config_space.actual_pages;
794            self.latest_stats.target_mib = pages_to_mib(self.latest_stats.target_pages);
795            self.latest_stats.actual_mib = pages_to_mib(self.latest_stats.actual_pages);
796            Ok(self.latest_stats)
797        } else {
798            Err(BalloonError::StatisticsDisabled)
799        }
800    }
801
802    /// Update the free page hinting cmd
803    pub fn update_free_page_hint_cmd(&mut self, cmd_id: u32) -> Result<(), BalloonError> {
804        if !self.is_activated() {
805            return Err(BalloonError::DeviceNotActive);
806        }
807
808        self.hinting_state.host_cmd = cmd_id;
809        self.config_space.free_page_hint_cmd_id = cmd_id;
810        self.interrupt_trigger()
811            .trigger(VirtioInterruptType::Config)
812            .map_err(BalloonError::InterruptError)
813    }
814
815    /// Starts a hinting run by setting the cmd_id to a new value.
816    pub(crate) fn start_hinting(&mut self, cmd: StartHintingCmd) -> Result<(), BalloonError> {
817        if !self.free_page_hinting() {
818            return Err(BalloonError::HintingNotEnabled);
819        }
820
821        let mut cmd_id = self.hinting_state.last_cmd_id.wrapping_add(1);
822        // 0 and 1 are reserved and cannot be used to start a hinting run
823        if cmd_id <= 1 {
824            cmd_id = 2;
825        }
826
827        self.hinting_state.acknowledge_on_finish = cmd.acknowledge_on_stop;
828        self.hinting_state.last_cmd_id = cmd_id;
829        self.update_free_page_hint_cmd(cmd_id)
830    }
831
832    /// Return the status of the hinting including the last command we sent to the driver
833    /// and the last cmd sent from the driver
834    pub(crate) fn get_hinting_status(&self) -> Result<HintingStatus, BalloonError> {
835        if !self.free_page_hinting() {
836            return Err(BalloonError::HintingNotEnabled);
837        }
838
839        Ok(HintingStatus {
840            host_cmd: self.hinting_state.host_cmd,
841            guest_cmd: self.hinting_state.guest_cmd,
842        })
843    }
844
845    /// Stops the hinting run allowing the guest to reclaim hinted pages
846    pub(crate) fn stop_hinting(&mut self) -> Result<(), BalloonError> {
847        if !self.free_page_hinting() {
848            Err(BalloonError::HintingNotEnabled)
849        } else {
850            self.update_free_page_hint_cmd(FREE_PAGE_HINT_DONE)
851        }
852    }
853
854    /// Return the config of the balloon device.
855    pub fn config(&self) -> BalloonConfig {
856        BalloonConfig {
857            amount_mib: self.size_mb(),
858            deflate_on_oom: self.deflate_on_oom(),
859            stats_polling_interval_s: self.stats_polling_interval_s(),
860            free_page_hinting: self.free_page_hinting(),
861            free_page_reporting: self.free_page_reporting(),
862        }
863    }
864
865    pub(crate) fn stats_enabled(&self) -> bool {
866        self.stats_polling_interval_s > 0
867    }
868
869    pub(crate) fn set_stats_desc_index(&mut self, stats_desc_index: Option<u16>) {
870        self.stats_desc_index = stats_desc_index;
871    }
872}
873
874impl VirtioDevice for Balloon {
875    impl_device_type!(VIRTIO_ID_BALLOON);
876
877    fn avail_features(&self) -> u64 {
878        self.avail_features
879    }
880
881    fn acked_features(&self) -> u64 {
882        self.acked_features
883    }
884
885    fn set_acked_features(&mut self, acked_features: u64) {
886        self.acked_features = acked_features;
887    }
888
889    fn queues(&self) -> &[Queue] {
890        &self.queues
891    }
892
893    fn queues_mut(&mut self) -> &mut [Queue] {
894        &mut self.queues
895    }
896
897    fn queue_events(&self) -> &[EventFd] {
898        &self.queue_evts
899    }
900
901    fn interrupt_trigger(&self) -> &dyn VirtioInterrupt {
902        self.device_state
903            .active_state()
904            .expect("Device is not activated")
905            .interrupt
906            .deref()
907    }
908
909    fn read_config(&self, offset: u64, data: &mut [u8]) {
910        if let Some(config_space_bytes) = self.config_space.as_slice().get(u64_to_usize(offset)..) {
911            let len = config_space_bytes.len().min(data.len());
912            data[..len].copy_from_slice(&config_space_bytes[..len]);
913        } else {
914            error!("Failed to read config space");
915        }
916    }
917
918    fn write_config(&mut self, offset: u64, data: &[u8]) {
919        let config_space_bytes = self.config_space.as_mut_slice();
920        let start = usize::try_from(offset).ok();
921        let end = start.and_then(|s| s.checked_add(data.len()));
922        let Some(dst) = start
923            .zip(end)
924            .and_then(|(start, end)| config_space_bytes.get_mut(start..end))
925        else {
926            error!("Failed to write config space");
927            return;
928        };
929
930        dst.copy_from_slice(data);
931    }
932
933    fn activate(
934        &mut self,
935        mem: GuestMemoryMmap,
936        interrupt: Arc<dyn VirtioInterrupt>,
937    ) -> Result<(), ActivateError> {
938        for q in self.queues.iter_mut() {
939            q.initialize(&mem)
940                .map_err(ActivateError::QueueMemoryError)?;
941        }
942
943        self.device_state = DeviceState::Activated(ActiveState { mem, interrupt });
944        if self.activate_evt.write(1).is_err() {
945            METRICS.activate_fails.inc();
946            self.device_state = DeviceState::Inactive;
947            return Err(ActivateError::EventFd);
948        }
949
950        if self.stats_enabled() {
951            self.update_timer_state();
952        }
953
954        // On activate ensure hint cmd is reset to FREE_PAGE_HINT_DONE
955        if self.is_activated() && self.free_page_hinting() {
956            self.update_free_page_hint_cmd(FREE_PAGE_HINT_DONE)
957                .map_err(|_| ActivateError::EventFd)?;
958        }
959
960        Ok(())
961    }
962
963    fn is_activated(&self) -> bool {
964        self.device_state.is_activated()
965    }
966
967    fn kick(&mut self) {
968        // If device is activated, kick the balloon queue(s) to make up for any
969        // pending or in-flight epoll events we may have not captured in snapshot.
970        // Stats queue doesn't need kicking as it is notified via a `timer_fd`.
971        if self.is_activated() {
972            info!("kick balloon {}.", self.id());
973            self.process_virtio_queues();
974        }
975    }
976}
977
978#[cfg(test)]
979pub(crate) mod tests {
980    use itertools::iproduct;
981
982    use super::super::BALLOON_CONFIG_SPACE_SIZE;
983    use super::*;
984    use crate::arch::host_page_size;
985    use crate::check_metric_after_block;
986    use crate::devices::virtio::balloon::report_balloon_event_fail;
987    use crate::devices::virtio::balloon::test_utils::{
988        check_request_completion, invoke_handler_for_queue_event, set_request,
989    };
990    use crate::devices::virtio::queue::{VIRTQ_DESC_F_NEXT, VIRTQ_DESC_F_WRITE};
991    use crate::devices::virtio::test_utils::test::{
992        VirtioTestDevice, VirtioTestHelper, create_virtio_mem,
993    };
994    use crate::devices::virtio::test_utils::{VirtQueue, default_interrupt, default_mem};
995    use crate::test_utils::single_region_mem;
996    use crate::utils::align_up;
997    use crate::vstate::memory::GuestAddress;
998
999    impl VirtioTestDevice for Balloon {
1000        fn set_queues(&mut self, queues: Vec<Queue>) {
1001            self.queues = queues;
1002        }
1003
1004        fn num_queues(&self) -> usize {
1005            let mut idx = STATS_INDEX;
1006
1007            if self.stats_polling_interval_s > 0 {
1008                idx += 1;
1009            }
1010
1011            if self.free_page_hinting() {
1012                idx += 1;
1013            }
1014
1015            if self.free_page_reporting() {
1016                idx += 1;
1017            }
1018
1019            idx
1020        }
1021    }
1022
1023    impl Balloon {
1024        pub(crate) fn set_queue(&mut self, idx: usize, q: Queue) {
1025            self.queues[idx] = q;
1026        }
1027
1028        pub(crate) fn actual_pages(&self) -> u32 {
1029            self.config_space.actual_pages
1030        }
1031
1032        pub fn update_num_pages(&mut self, num_pages: u32) {
1033            self.config_space.num_pages = num_pages;
1034        }
1035
1036        pub fn update_actual_pages(&mut self, actual_pages: u32) {
1037            self.config_space.actual_pages = actual_pages;
1038        }
1039    }
1040
1041    #[test]
1042    fn test_balloon_stat_size() {
1043        assert_eq!(SIZE_OF_STAT, 10);
1044    }
1045
1046    #[test]
1047    fn test_update_balloon_stats() {
1048        // Test all feature combinations.
1049        let mut stats = BalloonStats {
1050            target_pages: 5120,
1051            actual_pages: 2560,
1052            target_mib: 20,
1053            actual_mib: 10,
1054            swap_in: Some(0),
1055            swap_out: Some(0),
1056            major_faults: Some(0),
1057            minor_faults: Some(0),
1058            free_memory: Some(0),
1059            total_memory: Some(0),
1060            available_memory: Some(0),
1061            disk_caches: Some(0),
1062            hugetlb_allocations: Some(0),
1063            hugetlb_failures: Some(0),
1064            oom_kill: None,
1065            alloc_stall: None,
1066            async_scan: None,
1067            direct_scan: None,
1068            async_reclaim: None,
1069            direct_reclaim: None,
1070        };
1071
1072        let mut stat = BalloonStat {
1073            tag: VIRTIO_BALLOON_S_SWAP_IN,
1074            val: 1,
1075        };
1076
1077        stats.update_with_stat(&stat).unwrap();
1078        assert_eq!(stats.swap_in, Some(1));
1079        stat.tag = VIRTIO_BALLOON_S_SWAP_OUT;
1080        stats.update_with_stat(&stat).unwrap();
1081        assert_eq!(stats.swap_out, Some(1));
1082        stat.tag = VIRTIO_BALLOON_S_MAJFLT;
1083        stats.update_with_stat(&stat).unwrap();
1084        assert_eq!(stats.major_faults, Some(1));
1085        stat.tag = VIRTIO_BALLOON_S_MINFLT;
1086        stats.update_with_stat(&stat).unwrap();
1087        assert_eq!(stats.minor_faults, Some(1));
1088        stat.tag = VIRTIO_BALLOON_S_MEMFREE;
1089        stats.update_with_stat(&stat).unwrap();
1090        assert_eq!(stats.free_memory, Some(1));
1091        stat.tag = VIRTIO_BALLOON_S_MEMTOT;
1092        stats.update_with_stat(&stat).unwrap();
1093        assert_eq!(stats.total_memory, Some(1));
1094        stat.tag = VIRTIO_BALLOON_S_AVAIL;
1095        stats.update_with_stat(&stat).unwrap();
1096        assert_eq!(stats.available_memory, Some(1));
1097        stat.tag = VIRTIO_BALLOON_S_CACHES;
1098        stats.update_with_stat(&stat).unwrap();
1099        assert_eq!(stats.disk_caches, Some(1));
1100        stat.tag = VIRTIO_BALLOON_S_HTLB_PGALLOC;
1101        stats.update_with_stat(&stat).unwrap();
1102        assert_eq!(stats.hugetlb_allocations, Some(1));
1103        stat.tag = VIRTIO_BALLOON_S_HTLB_PGFAIL;
1104        stats.update_with_stat(&stat).unwrap();
1105        assert_eq!(stats.hugetlb_failures, Some(1));
1106        stat.tag = VIRTIO_BALLOON_S_OOM_KILL;
1107        stats.update_with_stat(&stat).unwrap();
1108        assert_eq!(stats.oom_kill, Some(1));
1109        stat.tag = VIRTIO_BALLOON_S_ALLOC_STALL;
1110        stats.update_with_stat(&stat).unwrap();
1111        assert_eq!(stats.alloc_stall, Some(1));
1112        stat.tag = VIRTIO_BALLOON_S_ASYNC_SCAN;
1113        stats.update_with_stat(&stat).unwrap();
1114        assert_eq!(stats.async_scan, Some(1));
1115        stat.tag = VIRTIO_BALLOON_S_DIRECT_SCAN;
1116        stats.update_with_stat(&stat).unwrap();
1117        assert_eq!(stats.direct_scan, Some(1));
1118        stat.tag = VIRTIO_BALLOON_S_ASYNC_RECLAIM;
1119        stats.update_with_stat(&stat).unwrap();
1120        assert_eq!(stats.async_reclaim, Some(1));
1121        stat.tag = VIRTIO_BALLOON_S_DIRECT_RECLAIM;
1122        stats.update_with_stat(&stat).unwrap();
1123        assert_eq!(stats.direct_reclaim, Some(1));
1124    }
1125
1126    #[test]
1127    fn test_virtio_features() {
1128        // Test all feature combinations.
1129        let combinations = iproduct!(
1130            &[true, false], // Reporitng
1131            &[true, false], // Hinting
1132            &[true, false], // Deflate
1133            &[0, 1]         // Interval
1134        );
1135
1136        for (reporting, hinting, deflate_on_oom, stats_interval) in combinations {
1137            let mut balloon =
1138                Balloon::new(0, *deflate_on_oom, *stats_interval, *hinting, *reporting).unwrap();
1139            assert_eq!(balloon.device_type(), VIRTIO_ID_BALLOON);
1140
1141            let features: u64 = (1u64 << VIRTIO_F_VERSION_1)
1142                | (u64::from(*deflate_on_oom) << VIRTIO_BALLOON_F_DEFLATE_ON_OOM)
1143                | ((u64::from(*reporting)) << VIRTIO_BALLOON_F_FREE_PAGE_REPORTING)
1144                | ((u64::from(*hinting)) << VIRTIO_BALLOON_F_FREE_PAGE_HINTING)
1145                | ((u64::from(*stats_interval)) << VIRTIO_BALLOON_F_STATS_VQ);
1146
1147            assert_eq!(
1148                balloon.avail_features_by_page(0),
1149                (features & 0xFFFFFFFF) as u32
1150            );
1151            assert_eq!(balloon.avail_features_by_page(1), (features >> 32) as u32);
1152            for i in 2..10 {
1153                assert_eq!(balloon.avail_features_by_page(i), 0u32);
1154            }
1155
1156            for i in 0..10 {
1157                balloon.ack_features_by_page(i, u32::MAX);
1158            }
1159            // Only present features should be acknowledged.
1160            assert_eq!(balloon.acked_features, features);
1161        }
1162    }
1163
1164    #[test]
1165    fn test_virtio_read_config() {
1166        let balloon = Balloon::new(0x10, true, 0, false, false).unwrap();
1167
1168        let cfg = BalloonConfig {
1169            amount_mib: 16,
1170            deflate_on_oom: true,
1171            stats_polling_interval_s: 0,
1172            free_page_hinting: false,
1173            free_page_reporting: false,
1174        };
1175        assert_eq!(balloon.config(), cfg);
1176
1177        let mut actual_config_space = [0u8; BALLOON_CONFIG_SPACE_SIZE];
1178        balloon.read_config(0, &mut actual_config_space);
1179        // The first 4 bytes are num_pages, the last 4 bytes are actual_pages.
1180        // The config space is little endian.
1181        // 0x10 MB in the constructor corresponds to 0x1000 pages in the
1182        // config space.
1183        let expected_config_space: [u8; BALLOON_CONFIG_SPACE_SIZE] = [
1184            0x00, 0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
1185        ];
1186        assert_eq!(actual_config_space, expected_config_space);
1187
1188        // Invalid read.
1189        let expected_config_space: [u8; BALLOON_CONFIG_SPACE_SIZE] = [
1190            0xd, 0xe, 0xa, 0xd, 0xb, 0xe, 0xe, 0xf, 0x00, 0x00, 0x00, 0x00,
1191        ];
1192        actual_config_space = expected_config_space;
1193        balloon.read_config(
1194            BALLOON_CONFIG_SPACE_SIZE as u64 + 1,
1195            &mut actual_config_space,
1196        );
1197
1198        // Validate read failed (the config space was not updated).
1199        assert_eq!(actual_config_space, expected_config_space);
1200    }
1201
1202    #[test]
1203    fn test_virtio_write_config() {
1204        let mut balloon = Balloon::new(0, true, 0, false, false).unwrap();
1205
1206        let expected_config_space: [u8; BALLOON_CONFIG_SPACE_SIZE] = [
1207            0x00, 0x50, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
1208        ];
1209        balloon.write_config(0, &expected_config_space);
1210
1211        let mut actual_config_space = [0u8; BALLOON_CONFIG_SPACE_SIZE];
1212        balloon.read_config(0, &mut actual_config_space);
1213        assert_eq!(actual_config_space, expected_config_space);
1214
1215        // Invalid write.
1216        let new_config_space = [
1217            0xd, 0xe, 0xa, 0xd, 0xb, 0xe, 0xe, 0xf, 0x00, 0x00, 0x00, 0x00,
1218        ];
1219        balloon.write_config(5, &new_config_space);
1220        // Make sure nothing got written.
1221        balloon.read_config(0, &mut actual_config_space);
1222        assert_eq!(actual_config_space, expected_config_space);
1223
1224        // Large offset that may cause an overflow.
1225        balloon.write_config(u64::MAX, &new_config_space);
1226        // Make sure nothing got written.
1227        balloon.read_config(0, &mut actual_config_space);
1228        assert_eq!(actual_config_space, expected_config_space);
1229    }
1230
1231    #[test]
1232    fn test_free_page_hinting_config() {
1233        let mut balloon = Balloon::new(0, true, 0, true, false).unwrap();
1234        let mem = default_mem();
1235        let interrupt = default_interrupt();
1236        let infq = VirtQueue::new(GuestAddress(0), &mem, 16);
1237        balloon.set_queue(INFLATE_INDEX, infq.create_queue());
1238        balloon.set_queue(DEFLATE_INDEX, infq.create_queue());
1239        balloon.set_queue(balloon.free_page_hinting_idx(), infq.create_queue());
1240        balloon.activate(mem.clone(), interrupt).unwrap();
1241
1242        let expected_config_space: [u8; BALLOON_CONFIG_SPACE_SIZE] = [
1243            0x00, 0x50, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
1244        ];
1245        balloon.write_config(0, &expected_config_space);
1246
1247        let mut actual_config_space = [0u8; BALLOON_CONFIG_SPACE_SIZE];
1248        balloon.read_config(0, &mut actual_config_space);
1249        assert_eq!(actual_config_space, expected_config_space);
1250
1251        // We expect the cmd_id to be set to 2 now
1252        balloon.start_hinting(Default::default()).unwrap();
1253
1254        let expected_config_space: [u8; BALLOON_CONFIG_SPACE_SIZE] = [
1255            0x00, 0x50, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00,
1256        ];
1257        let mut actual_config_space = [0u8; BALLOON_CONFIG_SPACE_SIZE];
1258        balloon.read_config(0, &mut actual_config_space);
1259        assert_eq!(actual_config_space, expected_config_space);
1260
1261        // We expect the cmd_id to be set to 1
1262        balloon.stop_hinting().unwrap();
1263
1264        let expected_config_space: [u8; BALLOON_CONFIG_SPACE_SIZE] = [
1265            0x00, 0x50, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
1266        ];
1267        let mut actual_config_space = [0u8; BALLOON_CONFIG_SPACE_SIZE];
1268        balloon.read_config(0, &mut actual_config_space);
1269        assert_eq!(actual_config_space, expected_config_space);
1270
1271        // We expect the cmd_id to be bumped up to 3 now
1272        balloon.start_hinting(Default::default()).unwrap();
1273
1274        let expected_config_space: [u8; BALLOON_CONFIG_SPACE_SIZE] = [
1275            0x00, 0x50, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00,
1276        ];
1277        let mut actual_config_space = [0u8; BALLOON_CONFIG_SPACE_SIZE];
1278        balloon.read_config(0, &mut actual_config_space);
1279        assert_eq!(actual_config_space, expected_config_space);
1280    }
1281
1282    #[test]
1283    fn test_invalid_request() {
1284        let mut balloon = Balloon::new(0, true, 0, false, false).unwrap();
1285        let mem = default_mem();
1286        let interrupt = default_interrupt();
1287        // Only initialize the inflate queue to demonstrate invalid request handling.
1288        let infq = VirtQueue::new(GuestAddress(0), &mem, 16);
1289        balloon.set_queue(INFLATE_INDEX, infq.create_queue());
1290        balloon.set_queue(DEFLATE_INDEX, infq.create_queue());
1291        balloon.activate(mem.clone(), interrupt).unwrap();
1292
1293        // Fill the second page with non-zero bytes.
1294        for i in 0..0x1000 {
1295            mem.write_obj::<u8>(1, GuestAddress((1 << 12) + i)).unwrap();
1296        }
1297
1298        // Will write the page frame number of the affected frame at this
1299        // arbitrary address in memory.
1300        let page_addr = 0x10;
1301
1302        // Invalid case: the descriptor is write-only.
1303        {
1304            mem.write_obj::<u32>(0x1, GuestAddress(page_addr)).unwrap();
1305            set_request(
1306                &infq,
1307                0,
1308                page_addr,
1309                SIZE_OF_U32.try_into().unwrap(),
1310                VIRTQ_DESC_F_NEXT | VIRTQ_DESC_F_WRITE,
1311            );
1312
1313            invoke_handler_for_queue_event(&mut balloon, INFLATE_INDEX);
1314            check_request_completion(&infq, 0);
1315
1316            // Check that the page was not zeroed.
1317            for i in 0..0x1000 {
1318                assert_eq!(mem.read_obj::<u8>(GuestAddress((1 << 12) + i)).unwrap(), 1);
1319            }
1320        }
1321
1322        // Invalid case: descriptor len is not a multiple of 'SIZE_OF_U32'.
1323        {
1324            mem.write_obj::<u32>(0x1, GuestAddress(page_addr)).unwrap();
1325            set_request(
1326                &infq,
1327                1,
1328                page_addr,
1329                u32::try_from(SIZE_OF_U32).unwrap() + 1,
1330                VIRTQ_DESC_F_NEXT,
1331            );
1332
1333            invoke_handler_for_queue_event(&mut balloon, INFLATE_INDEX);
1334            check_request_completion(&infq, 1);
1335
1336            // Check that the page was not zeroed.
1337            for i in 0..0x1000 {
1338                assert_eq!(mem.read_obj::<u8>(GuestAddress((1 << 12) + i)).unwrap(), 1);
1339            }
1340        }
1341    }
1342
1343    #[test]
1344    fn test_inflate() {
1345        let mut balloon = Balloon::new(0, true, 0, false, false).unwrap();
1346        let mem = default_mem();
1347        let interrupt = default_interrupt();
1348        let infq = VirtQueue::new(GuestAddress(0), &mem, 16);
1349        balloon.set_queue(INFLATE_INDEX, infq.create_queue());
1350        balloon.set_queue(DEFLATE_INDEX, infq.create_queue());
1351        balloon.activate(mem.clone(), interrupt).unwrap();
1352
1353        // Fill the third page with non-zero bytes.
1354        for i in 0..0x1000 {
1355            mem.write_obj::<u8>(1, GuestAddress((1 << 12) + i)).unwrap();
1356        }
1357
1358        // Will write the page frame number of the affected frame at this
1359        // arbitrary address in memory.
1360        let page_addr = 0x10;
1361
1362        // Error case: the request is well-formed, but we forgot
1363        // to trigger the inflate event queue.
1364        {
1365            mem.write_obj::<u32>(0x1, GuestAddress(page_addr)).unwrap();
1366            set_request(
1367                &infq,
1368                0,
1369                page_addr,
1370                SIZE_OF_U32.try_into().unwrap(),
1371                VIRTQ_DESC_F_NEXT,
1372            );
1373
1374            check_metric_after_block!(
1375                METRICS.event_fails,
1376                1,
1377                balloon
1378                    .process_inflate_queue_event()
1379                    .unwrap_or_else(report_balloon_event_fail)
1380            );
1381            // Verify that nothing got processed.
1382            assert_eq!(infq.used.idx.get(), 0);
1383
1384            // Check that the page was not zeroed.
1385            for i in 0..0x1000 {
1386                assert_eq!(mem.read_obj::<u8>(GuestAddress((1 << 12) + i)).unwrap(), 1);
1387            }
1388        }
1389
1390        // Test the happy case.
1391        {
1392            mem.write_obj::<u32>(0x1, GuestAddress(page_addr)).unwrap();
1393            set_request(
1394                &infq,
1395                0,
1396                page_addr,
1397                SIZE_OF_U32.try_into().unwrap(),
1398                VIRTQ_DESC_F_NEXT,
1399            );
1400
1401            check_metric_after_block!(
1402                METRICS.inflate_count,
1403                1,
1404                invoke_handler_for_queue_event(&mut balloon, INFLATE_INDEX)
1405            );
1406            check_request_completion(&infq, 0);
1407
1408            // Check that the page was zeroed.
1409            for i in 0..0x1000 {
1410                assert_eq!(mem.read_obj::<u8>(GuestAddress((1 << 12) + i)).unwrap(), 0);
1411            }
1412        }
1413    }
1414
1415    #[test]
1416    fn test_deflate() {
1417        let mut balloon = Balloon::new(0, true, 0, false, false).unwrap();
1418        let mem = default_mem();
1419        let interrupt = default_interrupt();
1420        let defq = VirtQueue::new(GuestAddress(0), &mem, 16);
1421        balloon.set_queue(INFLATE_INDEX, defq.create_queue());
1422        balloon.set_queue(DEFLATE_INDEX, defq.create_queue());
1423        balloon.activate(mem.clone(), interrupt).unwrap();
1424
1425        let page_addr = 0x10;
1426
1427        // Error case: forgot to trigger deflate event queue.
1428        {
1429            set_request(
1430                &defq,
1431                0,
1432                page_addr,
1433                SIZE_OF_U32.try_into().unwrap(),
1434                VIRTQ_DESC_F_NEXT,
1435            );
1436            check_metric_after_block!(
1437                METRICS.event_fails,
1438                1,
1439                balloon
1440                    .process_deflate_queue_event()
1441                    .unwrap_or_else(report_balloon_event_fail)
1442            );
1443            // Verify that nothing got processed.
1444            assert_eq!(defq.used.idx.get(), 0);
1445        }
1446
1447        // Happy case.
1448        {
1449            set_request(
1450                &defq,
1451                1,
1452                page_addr,
1453                SIZE_OF_U32.try_into().unwrap(),
1454                VIRTQ_DESC_F_NEXT,
1455            );
1456            check_metric_after_block!(
1457                METRICS.deflate_count,
1458                1,
1459                invoke_handler_for_queue_event(&mut balloon, DEFLATE_INDEX)
1460            );
1461            check_request_completion(&defq, 1);
1462        }
1463    }
1464
1465    #[test]
1466    fn test_stats() {
1467        let mut balloon = Balloon::new(0, true, 1, false, false).unwrap();
1468        let mem = default_mem();
1469        let interrupt = default_interrupt();
1470        let statsq = VirtQueue::new(GuestAddress(0), &mem, 16);
1471        balloon.set_queue(INFLATE_INDEX, statsq.create_queue());
1472        balloon.set_queue(DEFLATE_INDEX, statsq.create_queue());
1473        balloon.set_queue(STATS_INDEX, statsq.create_queue());
1474        balloon.activate(mem.clone(), interrupt).unwrap();
1475
1476        let page_addr = 0x100;
1477
1478        // Error case: forgot to trigger stats event queue.
1479        {
1480            set_request(
1481                &statsq,
1482                0,
1483                0x1000,
1484                SIZE_OF_STAT.try_into().unwrap(),
1485                VIRTQ_DESC_F_NEXT,
1486            );
1487            check_metric_after_block!(
1488                METRICS.event_fails,
1489                1,
1490                balloon
1491                    .process_stats_queue_event()
1492                    .unwrap_or_else(report_balloon_event_fail)
1493            );
1494            // Verify that nothing got processed.
1495            assert_eq!(statsq.used.idx.get(), 0);
1496        }
1497
1498        // Happy case.
1499        {
1500            let swap_out_stat = BalloonStat {
1501                tag: VIRTIO_BALLOON_S_SWAP_OUT,
1502                val: 0x1,
1503            };
1504            let mem_free_stat = BalloonStat {
1505                tag: VIRTIO_BALLOON_S_MEMFREE,
1506                val: 0x5678,
1507            };
1508
1509            // Write the stats in memory.
1510            mem.write_obj::<BalloonStat>(swap_out_stat, GuestAddress(page_addr))
1511                .unwrap();
1512            mem.write_obj::<BalloonStat>(
1513                mem_free_stat,
1514                GuestAddress(page_addr + SIZE_OF_STAT as u64),
1515            )
1516            .unwrap();
1517
1518            set_request(
1519                &statsq,
1520                0,
1521                page_addr,
1522                2 * u32::try_from(SIZE_OF_STAT).unwrap(),
1523                VIRTQ_DESC_F_NEXT,
1524            );
1525            check_metric_after_block!(METRICS.stats_updates_count, 1, {
1526                // Trigger the queue event.
1527                balloon.queue_events()[STATS_INDEX].write(1).unwrap();
1528                balloon.process_stats_queue_event().unwrap();
1529                // Don't check for completion yet.
1530            });
1531
1532            let stats = balloon.latest_stats().unwrap();
1533            let expected_stats = BalloonStats {
1534                swap_out: Some(0x1),
1535                free_memory: Some(0x5678),
1536                ..BalloonStats::default()
1537            };
1538            assert_eq!(stats, expected_stats);
1539
1540            // Wait for the timer to expire, although as it is non-blocking
1541            // we could just process the timer event and it would not
1542            // return an error.
1543            std::thread::sleep(Duration::from_secs(1));
1544            check_metric_after_block!(METRICS.event_fails, 0, {
1545                // Trigger the timer event, which consumes the stats
1546                // descriptor index and signals the used queue.
1547                assert!(balloon.stats_desc_index.is_some());
1548                balloon.process_stats_timer_event().unwrap();
1549                assert!(balloon.stats_desc_index.is_none());
1550                assert!(balloon.interrupt_trigger().has_pending_interrupt(
1551                    VirtioInterruptType::Queue(STATS_INDEX.try_into().unwrap())
1552                ));
1553            });
1554        }
1555    }
1556
1557    #[test]
1558    fn test_process_reporting() {
1559        let mem = create_virtio_mem();
1560        let mut th =
1561            VirtioTestHelper::<Balloon>::new(&mem, Balloon::new(0, true, 0, false, true).unwrap());
1562
1563        th.activate_device(&mem);
1564
1565        let page_size = host_page_size() as u64;
1566
1567        // This has to be u32 for the scatter gather
1568        #[allow(clippy::cast_possible_truncation)]
1569        let page_size_chain = page_size as u32;
1570        let reporting_idx = th.device().free_page_reporting_idx();
1571
1572        let safe_addr = align_up(th.data_address(), page_size);
1573
1574        th.add_scatter_gather(reporting_idx, 0, &[(0, safe_addr, page_size_chain, 0)]);
1575        check_metric_after_block!(
1576            METRICS.free_page_report_freed,
1577            page_size,
1578            invoke_handler_for_queue_event(&mut th.device(), reporting_idx)
1579        );
1580
1581        // Test with multiple items
1582        th.add_scatter_gather(
1583            reporting_idx,
1584            0,
1585            &[
1586                (0, safe_addr, page_size_chain, 0),
1587                (1, safe_addr + page_size, page_size_chain, 0),
1588                (2, safe_addr + (page_size * 2), page_size_chain, 0),
1589            ],
1590        );
1591
1592        check_metric_after_block!(
1593            METRICS.free_page_report_freed,
1594            page_size * 3,
1595            invoke_handler_for_queue_event(&mut th.device(), reporting_idx)
1596        );
1597
1598        // Test with unaligned length
1599        th.add_scatter_gather(reporting_idx, 0, &[(1, safe_addr + 1, page_size_chain, 0)]);
1600
1601        check_metric_after_block!(
1602            METRICS.free_page_report_fails,
1603            1,
1604            invoke_handler_for_queue_event(&mut th.device(), reporting_idx)
1605        );
1606    }
1607
1608    struct HintingTestHelper<'a> {
1609        mem: &'a GuestMemoryMmap,
1610        th: VirtioTestHelper<'a, Balloon>,
1611        page_size: u64,
1612        page_size_chain: u32,
1613        hinting_idx: usize,
1614        safe_addr: u64,
1615    }
1616
1617    impl<'a> HintingTestHelper<'a> {
1618        fn new(mem: &'a GuestMemoryMmap) -> Self {
1619            let mut th = VirtioTestHelper::<Balloon>::new(
1620                mem,
1621                Balloon::new(0, true, 0, true, false).unwrap(),
1622            );
1623            th.activate_device(mem);
1624
1625            let page_size = host_page_size() as u64;
1626            let hinting_idx = th.device().free_page_hinting_idx();
1627            let safe_addr = align_up(th.data_address(), page_size);
1628
1629            // Ack the config set on start
1630            th.device()
1631                .interrupt_trigger()
1632                .ack_interrupt(VirtioInterruptType::Config);
1633
1634            Self {
1635                mem,
1636                th,
1637                page_size,
1638                hinting_idx,
1639                // This has to be u32 for the scatter gather
1640                #[allow(clippy::cast_possible_truncation)]
1641                page_size_chain: page_size as u32,
1642                safe_addr,
1643            }
1644        }
1645
1646        fn start_hinting(&mut self, cmd: Option<StartHintingCmd>) {
1647            let cmd = cmd.unwrap_or_default();
1648            self.th.device().start_hinting(cmd).unwrap();
1649            assert!(
1650                self.th
1651                    .device()
1652                    .interrupt_trigger()
1653                    .has_pending_interrupt(VirtioInterruptType::Config)
1654            );
1655            self.th
1656                .device()
1657                .interrupt_trigger()
1658                .ack_interrupt(VirtioInterruptType::Config);
1659        }
1660
1661        fn send_stop(&mut self, cmd: Option<u32>) {
1662            let cmd = cmd.unwrap_or(FREE_PAGE_HINT_STOP);
1663
1664            self.mem
1665                .write_obj(cmd, GuestAddress::new(self.safe_addr))
1666                .unwrap();
1667            self.th.add_scatter_gather(
1668                self.hinting_idx,
1669                0,
1670                &[
1671                    (0, self.safe_addr, 4, VIRTQ_DESC_F_WRITE),
1672                    (
1673                        1,
1674                        self.safe_addr + self.page_size,
1675                        self.page_size_chain,
1676                        VIRTQ_DESC_F_WRITE,
1677                    ),
1678                ],
1679            );
1680            check_metric_after_block!(
1681                METRICS.free_page_hint_freed,
1682                0,
1683                self.th.device().process_free_page_hinting_queue()
1684            );
1685            self.th
1686                .device()
1687                .interrupt_trigger()
1688                .ack_interrupt(VirtioInterruptType::Queue(
1689                    self.hinting_idx.try_into().unwrap(),
1690                ));
1691            self.th
1692                .device()
1693                .interrupt_trigger()
1694                .ack_interrupt(VirtioInterruptType::Config);
1695        }
1696
1697        fn test_hinting(&mut self, cmd: Option<u32>, expected: u64) {
1698            let payload = match cmd {
1699                Some(c) => {
1700                    self.mem
1701                        .write_obj(c, GuestAddress::new(self.safe_addr))
1702                        .unwrap();
1703                    vec![
1704                        (0, self.safe_addr, 4, VIRTQ_DESC_F_WRITE),
1705                        (
1706                            1,
1707                            self.safe_addr + self.page_size,
1708                            self.page_size_chain,
1709                            VIRTQ_DESC_F_WRITE,
1710                        ),
1711                    ]
1712                }
1713                None => {
1714                    vec![(
1715                        0,
1716                        self.safe_addr + self.page_size,
1717                        self.page_size_chain,
1718                        VIRTQ_DESC_F_WRITE,
1719                    )]
1720                }
1721            };
1722            self.th.add_scatter_gather(self.hinting_idx, 0, &payload);
1723            check_metric_after_block!(
1724                METRICS.free_page_hint_freed,
1725                expected,
1726                invoke_handler_for_queue_event(&mut self.th.device(), self.hinting_idx)
1727            );
1728        }
1729    }
1730
1731    #[test]
1732    fn test_hinting_no_cmd_set() {
1733        let mem = create_virtio_mem();
1734        let mut ht = HintingTestHelper::new(&mem);
1735
1736        // Report a page before a cmd_id has even been negotiated
1737        ht.test_hinting(Some(2), 0);
1738    }
1739
1740    #[test]
1741    fn test_hinting_normal_path() {
1742        let mem = create_virtio_mem();
1743        let mut ht = HintingTestHelper::new(&mem);
1744
1745        // Test the good case
1746        ht.start_hinting(None);
1747
1748        let host_cmd = ht.th.device().get_hinting_status().unwrap().host_cmd;
1749
1750        // Ack the start of the hinting run and send a single page
1751        ht.test_hinting(Some(host_cmd), ht.page_size);
1752    }
1753
1754    #[test]
1755    fn test_hinting_invalid_cmd() {
1756        let mem = create_virtio_mem();
1757        let mut ht = HintingTestHelper::new(&mem);
1758
1759        // Test the good case
1760        ht.start_hinting(None);
1761        let host_cmd = ht.th.device().get_hinting_status().unwrap().host_cmd;
1762
1763        // Report pages for an invalid cmd
1764        ht.test_hinting(Some(host_cmd + 1), 0);
1765
1766        // If correct cmd is again used continue again
1767        ht.test_hinting(Some(host_cmd), ht.page_size);
1768    }
1769
1770    #[test]
1771    fn test_hinting_stale_inflight_requests() {
1772        let mem = create_virtio_mem();
1773        let mut ht = HintingTestHelper::new(&mem);
1774
1775        // Test the good case
1776        ht.start_hinting(None);
1777        let mut host_cmd = ht.th.device().get_hinting_status().unwrap().host_cmd;
1778
1779        ht.test_hinting(Some(host_cmd), ht.page_size);
1780
1781        // Trigger another hinting run this will bump the cmd id
1782        // so we should ignore any inflight requests
1783        ht.start_hinting(None);
1784        ht.test_hinting(None, 0);
1785
1786        // Update to our new host cmd and check this now works
1787        host_cmd = ht.th.device().get_hinting_status().unwrap().host_cmd;
1788        ht.test_hinting(Some(host_cmd), ht.page_size);
1789        ht.test_hinting(None, ht.page_size);
1790    }
1791
1792    #[test]
1793    fn test_hinting_stale_post_stop() {
1794        let mem = create_virtio_mem();
1795        let mut ht = HintingTestHelper::new(&mem);
1796
1797        // Test the good case
1798        ht.start_hinting(None);
1799        let mut host_cmd = ht.th.device().get_hinting_status().unwrap().host_cmd;
1800
1801        // Simulate the driver finishing a run. Any reported values after
1802        // should be ignored
1803        ht.send_stop(None);
1804        // Test we handle invalid cmd from driver
1805        ht.send_stop(Some(FREE_PAGE_HINT_DONE));
1806        ht.test_hinting(None, 0);
1807
1808        // As we had auto ack on finish the host cmd should be set to done
1809        host_cmd = ht.th.device().get_hinting_status().unwrap().host_cmd;
1810        assert_eq!(host_cmd, FREE_PAGE_HINT_DONE);
1811    }
1812
1813    #[test]
1814    fn test_hinting_no_ack_on_stop() {
1815        let mem = create_virtio_mem();
1816        let mut ht = HintingTestHelper::new(&mem);
1817
1818        // Test the good case
1819        ht.start_hinting(None);
1820        let mut host_cmd = ht.th.device().get_hinting_status().unwrap().host_cmd;
1821
1822        // Test no ack on stop behaviour
1823        ht.start_hinting(Some(StartHintingCmd {
1824            acknowledge_on_stop: false,
1825        }));
1826
1827        host_cmd = ht.th.device().get_hinting_status().unwrap().host_cmd;
1828        ht.test_hinting(Some(host_cmd), ht.page_size);
1829        ht.test_hinting(None, ht.page_size);
1830
1831        ht.send_stop(None);
1832        let new_host_cmd = ht.th.device().get_hinting_status().unwrap().host_cmd;
1833        assert_eq!(host_cmd, new_host_cmd);
1834    }
1835
1836    #[test]
1837    fn test_hinting_misaligned_value() {
1838        let mem = create_virtio_mem();
1839        let mut ht = HintingTestHelper::new(&mem);
1840
1841        // Test the good case
1842        ht.start_hinting(None);
1843        let mut host_cmd = ht.th.device().get_hinting_status().unwrap().host_cmd;
1844
1845        ht.test_hinting(Some(host_cmd), ht.page_size);
1846        ht.test_hinting(None, ht.page_size);
1847
1848        ht.th.add_scatter_gather(
1849            ht.hinting_idx,
1850            0,
1851            &[(0, ht.safe_addr + ht.page_size + 1, ht.page_size_chain, 0)],
1852        );
1853
1854        check_metric_after_block!(
1855            METRICS.free_page_hint_fails,
1856            1,
1857            ht.th.device().process_free_page_hinting_queue().unwrap()
1858        );
1859    }
1860
1861    #[test]
1862    fn test_process_balloon_queues() {
1863        let mut balloon = Balloon::new(0x10, true, 0, true, true).unwrap();
1864        let mem = default_mem();
1865        let interrupt = default_interrupt();
1866        let infq = VirtQueue::new(GuestAddress(0), &mem, 16);
1867        let defq = VirtQueue::new(GuestAddress(0), &mem, 16);
1868        let hintq = VirtQueue::new(GuestAddress(0), &mem, 16);
1869        let reportq = VirtQueue::new(GuestAddress(0), &mem, 16);
1870
1871        balloon.set_queue(INFLATE_INDEX, infq.create_queue());
1872        balloon.set_queue(DEFLATE_INDEX, defq.create_queue());
1873        balloon.set_queue(balloon.free_page_hinting_idx(), hintq.create_queue());
1874        balloon.set_queue(balloon.free_page_reporting_idx(), reportq.create_queue());
1875
1876        balloon.activate(mem, interrupt).unwrap();
1877        balloon.process_virtio_queues().unwrap();
1878    }
1879
1880    #[test]
1881    fn test_update_stats_interval() {
1882        let mut balloon = Balloon::new(0, true, 0, false, false).unwrap();
1883        let mem = default_mem();
1884        let q = VirtQueue::new(GuestAddress(0), &mem, 16);
1885        balloon.set_queue(INFLATE_INDEX, q.create_queue());
1886        balloon.set_queue(DEFLATE_INDEX, q.create_queue());
1887        let interrupt = default_interrupt();
1888        balloon.activate(mem, interrupt).unwrap();
1889        assert_eq!(
1890            format!("{:?}", balloon.update_stats_polling_interval(1)),
1891            "Err(StatisticsStateChange)"
1892        );
1893        balloon.update_stats_polling_interval(0).unwrap();
1894
1895        let mut balloon = Balloon::new(0, true, 1, false, false).unwrap();
1896        let mem = default_mem();
1897        let q = VirtQueue::new(GuestAddress(0), &mem, 16);
1898        balloon.set_queue(INFLATE_INDEX, q.create_queue());
1899        balloon.set_queue(DEFLATE_INDEX, q.create_queue());
1900        balloon.set_queue(STATS_INDEX, q.create_queue());
1901        let interrupt = default_interrupt();
1902        balloon.activate(mem, interrupt).unwrap();
1903        assert_eq!(
1904            format!("{:?}", balloon.update_stats_polling_interval(0)),
1905            "Err(StatisticsStateChange)"
1906        );
1907        balloon.update_stats_polling_interval(1).unwrap();
1908        balloon.update_stats_polling_interval(2).unwrap();
1909    }
1910
1911    #[test]
1912    fn test_cannot_update_inactive_device() {
1913        let mut balloon = Balloon::new(0, true, 0, false, false).unwrap();
1914        // Assert that we can't update an inactive device.
1915        balloon.update_size(1).unwrap_err();
1916        balloon.start_hinting(Default::default()).unwrap_err();
1917        balloon.get_hinting_status().unwrap_err();
1918        balloon.stop_hinting().unwrap_err();
1919    }
1920
1921    #[test]
1922    fn test_num_pages() {
1923        let mut balloon = Balloon::new(0, true, 0, false, false).unwrap();
1924        // Switch the state to active.
1925        balloon.device_state = DeviceState::Activated(ActiveState {
1926            mem: single_region_mem(32 << 20),
1927            interrupt: default_interrupt(),
1928        });
1929
1930        assert_eq!(balloon.num_pages(), 0);
1931        assert_eq!(balloon.actual_pages(), 0);
1932
1933        // Update fields through the API.
1934        balloon.update_actual_pages(0x1234);
1935        balloon.update_num_pages(0x100);
1936        assert_eq!(balloon.num_pages(), 0x100);
1937        balloon.update_size(16).unwrap();
1938
1939        let mut actual_config = vec![0; BALLOON_CONFIG_SPACE_SIZE];
1940        balloon.read_config(0, &mut actual_config);
1941        assert_eq!(
1942            actual_config,
1943            vec![0x0, 0x10, 0x0, 0x0, 0x34, 0x12, 0, 0, 0, 0, 0, 0]
1944        );
1945        assert_eq!(balloon.num_pages(), 0x1000);
1946        assert_eq!(balloon.actual_pages(), 0x1234);
1947        assert_eq!(balloon.size_mb(), 16);
1948
1949        // Update fields through the config space.
1950        let expected_config = vec![0x44, 0x33, 0x22, 0x11, 0x78, 0x56, 0x34, 0x12, 0, 0, 0, 0];
1951        balloon.write_config(0, &expected_config);
1952        assert_eq!(balloon.num_pages(), 0x1122_3344);
1953        assert_eq!(balloon.actual_pages(), 0x1234_5678);
1954    }
1955}