vmm/rate_limiter/
persist.rs1use serde::{Deserialize, Serialize};
7
8use super::*;
9use crate::snapshot::Persist;
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct TokenBucketState {
14 size: u64,
15 one_time_burst: u64,
16 refill_time: u64,
17 budget: u64,
18 elapsed_ns: u64,
19}
20
21impl Persist<'_> for TokenBucket {
22 type State = TokenBucketState;
23 type ConstructorArgs = ();
24 type Error = io::Error;
25
26 fn save(&self) -> Self::State {
27 TokenBucketState {
28 size: self.size,
29 one_time_burst: self.one_time_burst,
30 refill_time: self.refill_time,
31 budget: self.budget,
32 elapsed_ns: u64::try_from(self.last_update.elapsed().as_nanos()).unwrap(),
34 }
35 }
36
37 fn restore(_: Self::ConstructorArgs, state: &Self::State) -> Result<Self, Self::Error> {
38 let now = Instant::now();
39 let last_update = now
40 .checked_sub(Duration::from_nanos(state.elapsed_ns))
41 .unwrap_or(now);
42
43 let mut token_bucket =
44 TokenBucket::new(state.size, state.one_time_burst, state.refill_time)
45 .ok_or_else(|| io::Error::from(io::ErrorKind::InvalidInput))?;
46
47 token_bucket.budget = state.budget;
48 token_bucket.last_update = last_update;
49
50 Ok(token_bucket)
51 }
52}
53
54#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct RateLimiterState {
57 ops: Option<TokenBucketState>,
58 bandwidth: Option<TokenBucketState>,
59}
60
61impl Persist<'_> for RateLimiter {
62 type State = RateLimiterState;
63 type ConstructorArgs = ();
64 type Error = io::Error;
65
66 fn save(&self) -> Self::State {
67 RateLimiterState {
68 ops: self.ops.as_ref().map(|ops| ops.save()),
69 bandwidth: self.bandwidth.as_ref().map(|bw| bw.save()),
70 }
71 }
72
73 fn restore(_: Self::ConstructorArgs, state: &Self::State) -> Result<Self, Self::Error> {
74 let rate_limiter = RateLimiter {
75 ops: if let Some(ops) = state.ops.as_ref() {
76 Some(TokenBucket::restore((), ops)?)
77 } else {
78 None
79 },
80 bandwidth: if let Some(bw) = state.bandwidth.as_ref() {
81 Some(TokenBucket::restore((), bw)?)
82 } else {
83 None
84 },
85 timer_fd: TimerFd::new_custom(ClockId::Monotonic, true, true)?,
86 timer_active: false,
87 };
88
89 Ok(rate_limiter)
90 }
91}
92
#[cfg(test)]
mod tests {

    use super::*;
    use crate::snapshot::Snapshot;

    /// Saves `bucket` and immediately restores a new bucket from that state.
    fn roundtrip_bucket(bucket: &TokenBucket) -> TokenBucket {
        TokenBucket::restore((), &bucket.save()).unwrap()
    }

    /// Saves `limiter` and immediately restores a new limiter from that state.
    fn roundtrip_limiter(limiter: &RateLimiter) -> RateLimiter {
        RateLimiter::restore((), &limiter.save()).expect("Unable to restore rate limiter")
    }

    /// Asserts that the ops and bandwidth buckets of `original` are
    /// `partial_eq` to those of `restored`. Both limiters must have both buckets.
    fn assert_same_buckets(original: &RateLimiter, restored: &RateLimiter) {
        let original_ops = original.ops().unwrap();
        assert!(original_ops.partial_eq(restored.ops().unwrap()));

        let original_bandwidth = original.bandwidth().unwrap();
        assert!(original_bandwidth.partial_eq(restored.bandwidth().unwrap()));
    }

    #[test]
    fn test_token_bucket_persistence() {
        let mut tb = TokenBucket::new(1000, 2000, 3000).unwrap();

        // A freshly created bucket survives a save/restore cycle.
        let restored_tb = roundtrip_bucket(&tb);
        assert!(tb.partial_eq(&restored_tb));

        // Still true after `reduce`.
        tb.reduce(100);
        let restored_tb = roundtrip_bucket(&tb);
        assert!(tb.partial_eq(&restored_tb));

        // Still true after `force_replenish`.
        tb.force_replenish(100);
        let restored_tb = roundtrip_bucket(&tb);
        assert!(tb.partial_eq(&restored_tb));

        // Finally, pass the saved state through a serialized `Snapshot` buffer.
        let mut mem = vec![0; 4096];
        Snapshot::new(tb.save())
            .save(&mut mem.as_mut_slice())
            .unwrap();

        let restored_tb = TokenBucket::restore(
            (),
            &Snapshot::load_without_crc_check(mem.as_slice())
                .unwrap()
                .data,
        )
        .unwrap();
        assert!(tb.partial_eq(&restored_tb));
    }

    #[test]
    fn test_rate_limiter_persistence() {
        let refill_time = 100_000;
        let mut rate_limiter = RateLimiter::new(100, 0, refill_time, 10, 0, refill_time).unwrap();

        // An untouched limiter restores with matching buckets and a disarmed timer.
        let restored_rate_limiter = roundtrip_limiter(&rate_limiter);
        assert_same_buckets(&rate_limiter, &restored_rate_limiter);
        assert_eq!(
            restored_rate_limiter.timer_fd.get_state(),
            TimerState::Disarmed
        );

        // Same expectations after consuming some bytes and ops.
        rate_limiter.consume(10, TokenType::Bytes);
        rate_limiter.consume(10, TokenType::Ops);
        let restored_rate_limiter = roundtrip_limiter(&rate_limiter);
        assert_same_buckets(&rate_limiter, &restored_rate_limiter);
        assert_eq!(
            restored_rate_limiter.timer_fd.get_state(),
            TimerState::Disarmed
        );

        // Buckets still match after a much larger bytes request.
        rate_limiter.consume(1000, TokenType::Bytes);
        let restored_rate_limiter = roundtrip_limiter(&rate_limiter);
        assert_same_buckets(&rate_limiter, &restored_rate_limiter);

        // Finally, pass the saved state through a serialized `Snapshot` buffer.
        let mut mem = vec![0; 4096];
        Snapshot::new(rate_limiter.save())
            .save(&mut mem.as_mut_slice())
            .unwrap();
        let restored_rate_limiter = RateLimiter::restore(
            (),
            &Snapshot::load_without_crc_check(mem.as_slice())
                .unwrap()
                .data,
        )
        .unwrap();

        assert_same_buckets(&rate_limiter, &restored_rate_limiter);
    }
}