vmm/rate_limiter/
persist.rs

// Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

//! Defines the structures needed for saving/restoring a RateLimiter and its TokenBuckets.
5
6use serde::{Deserialize, Serialize};
7
8use super::*;
9use crate::snapshot::Persist;
10
11/// State for saving a TokenBucket.
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct TokenBucketState {
14    size: u64,
15    one_time_burst: u64,
16    refill_time: u64,
17    budget: u64,
18    elapsed_ns: u64,
19}
20
21impl Persist<'_> for TokenBucket {
22    type State = TokenBucketState;
23    type ConstructorArgs = ();
24    type Error = io::Error;
25
26    fn save(&self) -> Self::State {
27        TokenBucketState {
28            size: self.size,
29            one_time_burst: self.one_time_burst,
30            refill_time: self.refill_time,
31            budget: self.budget,
32            // This should be safe for a duration of about 584 years.
33            elapsed_ns: u64::try_from(self.last_update.elapsed().as_nanos()).unwrap(),
34        }
35    }
36
37    fn restore(_: Self::ConstructorArgs, state: &Self::State) -> Result<Self, Self::Error> {
38        let now = Instant::now();
39        let last_update = now
40            .checked_sub(Duration::from_nanos(state.elapsed_ns))
41            .unwrap_or(now);
42
43        let mut token_bucket =
44            TokenBucket::new(state.size, state.one_time_burst, state.refill_time)
45                .ok_or_else(|| io::Error::from(io::ErrorKind::InvalidInput))?;
46
47        token_bucket.budget = state.budget;
48        token_bucket.last_update = last_update;
49
50        Ok(token_bucket)
51    }
52}
53
54/// State for saving a RateLimiter.
55#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct RateLimiterState {
57    ops: Option<TokenBucketState>,
58    bandwidth: Option<TokenBucketState>,
59}
60
61impl Persist<'_> for RateLimiter {
62    type State = RateLimiterState;
63    type ConstructorArgs = ();
64    type Error = io::Error;
65
66    fn save(&self) -> Self::State {
67        RateLimiterState {
68            ops: self.ops.as_ref().map(|ops| ops.save()),
69            bandwidth: self.bandwidth.as_ref().map(|bw| bw.save()),
70        }
71    }
72
73    fn restore(_: Self::ConstructorArgs, state: &Self::State) -> Result<Self, Self::Error> {
74        let rate_limiter = RateLimiter {
75            ops: if let Some(ops) = state.ops.as_ref() {
76                Some(TokenBucket::restore((), ops)?)
77            } else {
78                None
79            },
80            bandwidth: if let Some(bw) = state.bandwidth.as_ref() {
81                Some(TokenBucket::restore((), bw)?)
82            } else {
83                None
84            },
85            timer_fd: TimerFd::new_custom(ClockId::Monotonic, true, true)?,
86            timer_active: false,
87        };
88
89        Ok(rate_limiter)
90    }
91}
92
#[cfg(test)]
mod tests {

    use super::*;
    use crate::snapshot::Snapshot;

    /// Round-trips `tb` through `save`/`restore` and checks the result matches it.
    fn assert_token_bucket_round_trip(tb: &TokenBucket) {
        let restored_tb = TokenBucket::restore((), &tb.save()).unwrap();
        assert!(tb.partial_eq(&restored_tb));
    }

    /// Asserts that both buckets of `restored` match the ones in `original`.
    fn assert_buckets_match(original: &RateLimiter, restored: &RateLimiter) {
        assert!(original
            .ops()
            .unwrap()
            .partial_eq(restored.ops().unwrap()));
        assert!(original
            .bandwidth()
            .unwrap()
            .partial_eq(restored.bandwidth().unwrap()));
    }

    /// Restores `rate_limiter` from a fresh save and checks its buckets and timer.
    fn assert_rate_limiter_round_trip(rate_limiter: &RateLimiter) {
        let restored =
            RateLimiter::restore((), &rate_limiter.save()).expect("Unable to restore rate limiter");
        assert_buckets_match(rate_limiter, &restored);
        // The timer is not persisted, so a restored limiter must always start disarmed.
        assert_eq!(restored.timer_fd.get_state(), TimerState::Disarmed);
    }

    #[test]
    fn test_token_bucket_persistence() {
        let mut tb = TokenBucket::new(1000, 2000, 3000).unwrap();

        // Check that TokenBucket restores correctly if untouched.
        assert_token_bucket_round_trip(&tb);

        // Check that TokenBucket restores correctly after partially consuming tokens.
        tb.reduce(100);
        assert_token_bucket_round_trip(&tb);

        // Check that TokenBucket restores correctly after replenishing tokens.
        tb.force_replenish(100);
        assert_token_bucket_round_trip(&tb);

        // Test serialization.
        let mut mem = vec![0; 4096];
        Snapshot::new(tb.save())
            .save(&mut mem.as_mut_slice())
            .unwrap();

        let restored_tb = TokenBucket::restore(
            (),
            &Snapshot::load_without_crc_check(mem.as_slice())
                .unwrap()
                .data,
        )
        .unwrap();
        assert!(tb.partial_eq(&restored_tb));
    }

    #[test]
    fn test_rate_limiter_persistence() {
        let refill_time = 100_000;
        let mut rate_limiter = RateLimiter::new(100, 0, refill_time, 10, 0, refill_time).unwrap();

        // Check that RateLimiter restores correctly if untouched.
        assert_rate_limiter_round_trip(&rate_limiter);

        // Check that RateLimiter restores correctly after partially consuming tokens.
        rate_limiter.consume(10, TokenType::Bytes);
        rate_limiter.consume(10, TokenType::Ops);
        assert_rate_limiter_round_trip(&rate_limiter);

        // Check that RateLimiter restores correctly after totally consuming tokens.
        // This may leave the original limiter's timer armed; the restored one must
        // still come back disarmed.
        rate_limiter.consume(1000, TokenType::Bytes);
        assert_rate_limiter_round_trip(&rate_limiter);

        // Test serialization.
        let mut mem = vec![0; 4096];
        Snapshot::new(rate_limiter.save())
            .save(&mut mem.as_mut_slice())
            .unwrap();
        let restored_rate_limiter = RateLimiter::restore(
            (),
            &Snapshot::load_without_crc_check(mem.as_slice())
                .unwrap()
                .data,
        )
        .unwrap();

        assert_buckets_match(&rate_limiter, &restored_rate_limiter);
        assert_eq!(
            restored_rate_limiter.timer_fd.get_state(),
            TimerState::Disarmed
        );
    }
}