vmm/rate_limiter/
mod.rs

1// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::os::unix::io::{AsRawFd, RawFd};
5use std::time::{Duration, Instant};
6use std::{fmt, io};
7
8use timerfd::{ClockId, SetTimeFlags, TimerFd, TimerState};
9
10pub mod persist;
11
#[derive(Debug, thiserror::Error, displaydoc::Display)]
/// Describes the errors that may occur while handling rate limiter events.
// NOTE: `displaydoc::Display` builds the `Display` output from each variant's doc comment, so the
// variant doc comment below is a runtime message (with `{0}` substituted by the payload), not
// just documentation. Reword it only if the user-visible message is meant to change.
pub enum RateLimiterError {
    /// The event handler was called spuriously: {0}
    // Returned by `RateLimiter::event_handler()` when the timer read reports nothing to handle;
    // the payload is a static description of the situation.
    SpuriousRateLimiterEvent(&'static str),
}
18
// Delay, in milliseconds, of the one-shot refill timer that is armed when a `consume()` call
// fails because the bucket does not hold enough tokens.
const REFILL_TIMER_INTERVAL_MS: u64 = 100;
// Timer state used to arm the refill timer: fires once after `REFILL_TIMER_INTERVAL_MS`.
const TIMER_REFILL_STATE: TimerState =
    TimerState::Oneshot(Duration::from_millis(REFILL_TIMER_INTERVAL_MS));

// Conversion factor between milliseconds (public configuration unit) and nanoseconds (unit used
// internally when replenishing a bucket).
const NANOSEC_IN_ONE_MILLISEC: u64 = 1_000_000;
25
// Greatest common divisor of `x` and `y`, computed with the iterative form of Euclid's
// algorithm: the pair (a, b) is repeatedly replaced by (b, a mod b) until the second member
// reaches zero, at which point the first member is the result.
//
// If one argument is zero the other one is returned unchanged.
#[cfg_attr(kani, kani::requires(x > 0 && y > 0))]
#[cfg_attr(kani, kani::ensures(
    |&result| result != 0
        && x % result == 0
        && y % result == 0
))]
fn gcd(x: u64, y: u64) -> u64 {
    let (mut dividend, mut divisor) = (x, y);
    while divisor != 0 {
        (dividend, divisor) = (divisor, dividend % divisor);
    }
    dividend
}
43
/// Enum describing the outcomes of a `reduce()` call on a `TokenBucket`.
#[derive(Clone, Debug, PartialEq)]
pub enum BucketReduction {
    /// There are not enough tokens to complete the operation; the requested tokens were not
    /// taken from the regular budget.
    Failure,
    /// The requested tokens were available (from the one-time burst and/or the regular budget)
    /// and have been consumed.
    Success,
    /// More tokens than the bucket size were requested. The bucket has been emptied and the
    /// payload is the part of the request that the budget could not cover, expressed as a
    /// multiple of the bucket size.
    OverConsumption(f64),
}
54
/// TokenBucket provides a lower level interface to rate limiting with a
/// configurable capacity, refill-rate and initial burst.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct TokenBucket {
    // Bucket defining traits.
    // Total capacity of the bucket, in tokens; the regular budget never exceeds it.
    size: u64,
    // Initial burst size.
    initial_one_time_burst: u64,
    // Complete refill time in milliseconds.
    refill_time: u64,

    // Internal state descriptors.

    // Number of free initial tokens, that can be consumed at no cost.
    one_time_burst: u64,
    // Current token budget.
    budget: u64,
    // Last time this token bucket saw activity.
    last_update: Instant,

    // Fields used for pre-processing optimizations.
    // `size` divided by the greatest common divisor of `size` and the refill time in
    // nanoseconds.
    processed_capacity: u64,
    // Refill time in nanoseconds divided by that same greatest common divisor.
    processed_refill_time: u64,
}
79
80impl TokenBucket {
81    /// Creates a `TokenBucket` wrapped in an `Option`.
82    ///
83    /// TokenBucket created is of `size` total capacity and takes `complete_refill_time_ms`
84    /// milliseconds to go from zero tokens to total capacity. The `one_time_burst` is initial
85    /// extra credit on top of total capacity, that does not replenish and which can be used
86    /// for an initial burst of data.
87    ///
88    /// If the `size` or the `complete refill time` are zero, then `None` is returned.
89    pub fn new(size: u64, one_time_burst: u64, complete_refill_time_ms: u64) -> Option<Self> {
90        // If either token bucket capacity or refill time is 0, disable limiting.
91        if size == 0 || complete_refill_time_ms == 0 {
92            return None;
93        }
94        // Formula for computing current refill amount:
95        // refill_token_count = (delta_time * size) / (complete_refill_time_ms * 1_000_000)
96        // In order to avoid overflows, simplify the fractions by computing greatest common divisor.
97
98        let complete_refill_time_ns =
99            complete_refill_time_ms.checked_mul(NANOSEC_IN_ONE_MILLISEC)?;
100        // Get the greatest common factor between `size` and `complete_refill_time_ns`.
101        let common_factor = gcd(size, complete_refill_time_ns);
102        // The division will be exact since `common_factor` is a factor of `size`.
103        let processed_capacity: u64 = size / common_factor;
104        // The division will be exact since `common_factor` is a factor of
105        // `complete_refill_time_ns`.
106        let processed_refill_time: u64 = complete_refill_time_ns / common_factor;
107
108        Some(TokenBucket {
109            size,
110            one_time_burst,
111            initial_one_time_burst: one_time_burst,
112            refill_time: complete_refill_time_ms,
113            // Start off full.
114            budget: size,
115            // Last updated is now.
116            last_update: Instant::now(),
117            processed_capacity,
118            processed_refill_time,
119        })
120    }
121
122    // Replenishes token bucket based on elapsed time. Should only be called internally by `Self`.
123    #[allow(clippy::cast_possible_truncation)]
124    fn auto_replenish(&mut self) {
125        // Compute time passed since last refill/update.
126        let now = Instant::now();
127        let time_delta = (now - self.last_update).as_nanos();
128
129        if time_delta >= u128::from(self.refill_time * NANOSEC_IN_ONE_MILLISEC) {
130            self.budget = self.size;
131            self.last_update = now;
132        } else {
133            // At each 'time_delta' nanoseconds the bucket should refill with:
134            // refill_amount = (time_delta * size) / (complete_refill_time_ms * 1_000_000)
135            // `processed_capacity` and `processed_refill_time` are the result of simplifying above
136            // fraction formula with their greatest-common-factor.
137
138            // In the constructor, we assured that (self.refill_time * NANOSEC_IN_ONE_MILLISEC)
139            // fits into a u64 That means, at this point we know that time_delta <
140            // u64::MAX. Since all other values here are u64, this assures that u128
141            // multiplication cannot overflow.
142            let processed_capacity = u128::from(self.processed_capacity);
143            let processed_refill_time = u128::from(self.processed_refill_time);
144
145            let tokens = (time_delta * processed_capacity) / processed_refill_time;
146
147            // We increment `self.last_update` by the minimum time required to generate `tokens`, in
148            // the case where we have the time to generate `1.8` tokens but only
149            // generate `x` tokens due to integer arithmetic this will carry the time
150            // required to generate 0.8th of a token over to the next call, such that if
151            // the next call where to generate `2.3` tokens it would instead
152            // generate `3.1` tokens. This minimizes dropping tokens at high frequencies.
153            // We want the integer division here to round up instead of down (as if we round down,
154            // we would allow some fraction of a nano second to be used twice, allowing
155            // for the generation of one extra token in extreme circumstances).
156            let mut time_adjustment = tokens * processed_refill_time / processed_capacity;
157            if tokens * processed_refill_time % processed_capacity != 0 {
158                time_adjustment += 1;
159            }
160
161            // Ensure that we always generate as many tokens as we can: assert that the "unused"
162            // part of time_delta is less than the time it would take to generate a
163            // single token (= processed_refill_time / processed_capacity)
164            debug_assert!(time_adjustment <= time_delta);
165            debug_assert!(
166                (time_delta - time_adjustment) * processed_capacity <= processed_refill_time
167            );
168
169            // time_adjustment is at most time_delta, and since time_delta <= u64::MAX, this cast is
170            // fine
171            self.last_update += Duration::from_nanos(time_adjustment as u64);
172            self.budget = std::cmp::min(self.budget.saturating_add(tokens as u64), self.size);
173        }
174    }
175
176    /// Attempts to consume `tokens` from the bucket and returns whether the action succeeded.
177    pub fn reduce(&mut self, mut tokens: u64) -> BucketReduction {
178        // First things first: consume the one-time-burst budget.
179        if self.one_time_burst > 0 {
180            // We still have burst budget for *all* tokens requests.
181            if self.one_time_burst >= tokens {
182                self.one_time_burst -= tokens;
183                self.last_update = Instant::now();
184                // No need to continue to the refill process, we still have burst budget to consume
185                // from.
186                return BucketReduction::Success;
187            } else {
188                // We still have burst budget for *some* of the tokens requests.
189                // The tokens left unfulfilled will be consumed from current `self.budget`.
190                tokens -= self.one_time_burst;
191                self.one_time_burst = 0;
192            }
193        }
194
195        if tokens > self.budget {
196            // Hit the bucket bottom, let's auto-replenish and try again.
197            self.auto_replenish();
198
199            // This operation requests a bandwidth higher than the bucket size
200            if tokens > self.size {
201                crate::logger::error!(
202                    "Consumed {} tokens from bucket of size {}",
203                    tokens,
204                    self.size
205                );
206                // Empty the bucket and report an overconsumption of
207                // (remaining tokens / size) times larger than the bucket size
208                tokens -= self.budget;
209                self.budget = 0;
210                return BucketReduction::OverConsumption(tokens as f64 / self.size as f64);
211            }
212
213            if tokens > self.budget {
214                // Still not enough tokens, consume() fails, return false.
215                return BucketReduction::Failure;
216            }
217        }
218
219        self.budget -= tokens;
220        BucketReduction::Success
221    }
222
223    /// "Manually" adds tokens to bucket.
224    pub fn force_replenish(&mut self, tokens: u64) {
225        // This means we are still during the burst interval.
226        // Of course there is a very small chance  that the last reduce() also used up burst
227        // budget which should now be replenished, but for performance and code-complexity
228        // reasons we're just gonna let that slide since it's practically inconsequential.
229        if self.one_time_burst > 0 {
230            self.one_time_burst = std::cmp::min(
231                self.one_time_burst.saturating_add(tokens),
232                self.initial_one_time_burst,
233            );
234            return;
235        }
236        self.budget = std::cmp::min(self.budget.saturating_add(tokens), self.size);
237    }
238
239    /// Returns the capacity of the token bucket.
240    pub fn capacity(&self) -> u64 {
241        self.size
242    }
243
244    /// Returns the remaining one time burst budget.
245    pub fn one_time_burst(&self) -> u64 {
246        self.one_time_burst
247    }
248
249    /// Returns the time in milliseconds required to to completely fill the bucket.
250    pub fn refill_time_ms(&self) -> u64 {
251        self.refill_time
252    }
253
254    /// Returns the current budget (one time burst allowance notwithstanding).
255    pub fn budget(&self) -> u64 {
256        self.budget
257    }
258
259    /// Returns the initially configured one time burst budget.
260    pub fn initial_one_time_burst(&self) -> u64 {
261        self.initial_one_time_burst
262    }
263}
264
/// Enum that describes the type of token used.
///
/// It selects which of the two buckets of a `RateLimiter` an operation applies to.
#[derive(Debug)]
pub enum TokenType {
    /// Token type used for bandwidth limiting.
    Bytes,
    /// Token type used for operations/second limiting.
    Ops,
}
273
/// Enum that describes the type of token bucket update.
///
/// One value per bucket is passed to `RateLimiter::update_buckets()`.
#[derive(Debug)]
pub enum BucketUpdate {
    /// No Update - same as before.
    None,
    /// Rate Limiting is disabled on this bucket.
    Disabled,
    /// Rate Limiting enabled with updated bucket.
    Update(TokenBucket),
}
284
/// Rate Limiter that works on both bandwidth and ops/s limiting.
///
/// Bandwidth (bytes/s) and ops/s limiting can be used at the same time or individually.
///
/// Implementation uses a single timer through TimerFd to refresh either or
/// both token buckets.
///
/// Its internal buckets are 'passively' replenished as they're being used (as
/// part of `consume()` operations).
/// A timer is enabled and used to 'actively' replenish the token buckets when
/// limiting is in effect and `consume()` operations are disabled.
///
/// RateLimiters will generate events on the FDs provided by their `AsRawFd` trait
/// implementation. These events are meant to be consumed by the user of this struct.
/// On each such event, the user must call the `event_handler()` method.
pub struct RateLimiter {
    // Bucket limiting bytes; `None` means bandwidth limiting is disabled.
    bandwidth: Option<TokenBucket>,
    // Bucket limiting operations; `None` means ops limiting is disabled.
    ops: Option<TokenBucket>,

    // Timer whose expiration unblocks the limiter (see `event_handler()`).
    timer_fd: TimerFd,
    // Internal flag that quickly determines timer state.
    // While it is `true`, every `consume()` call fails.
    timer_active: bool,
}
308
309impl PartialEq for RateLimiter {
310    fn eq(&self, other: &RateLimiter) -> bool {
311        self.bandwidth == other.bandwidth && self.ops == other.ops
312    }
313}
314
315impl fmt::Debug for RateLimiter {
316    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
317        write!(
318            f,
319            "RateLimiter {{ bandwidth: {:?}, ops: {:?} }}",
320            self.bandwidth, self.ops
321        )
322    }
323}
324
325impl RateLimiter {
326    /// Creates a new Rate Limiter that can limit on both bytes/s and ops/s.
327    ///
328    /// # Arguments
329    ///
330    /// * `bytes_total_capacity` - the total capacity of the `TokenType::Bytes` token bucket.
331    /// * `bytes_one_time_burst` - initial extra credit on top of `bytes_total_capacity`, that does
332    ///   not replenish and which can be used for an initial burst of data.
333    /// * `bytes_complete_refill_time_ms` - number of milliseconds for the `TokenType::Bytes` token
334    ///   bucket to go from zero Bytes to `bytes_total_capacity` Bytes.
335    /// * `ops_total_capacity` - the total capacity of the `TokenType::Ops` token bucket.
336    /// * `ops_one_time_burst` - initial extra credit on top of `ops_total_capacity`, that does not
337    ///   replenish and which can be used for an initial burst of data.
338    /// * `ops_complete_refill_time_ms` - number of milliseconds for the `TokenType::Ops` token
339    ///   bucket to go from zero Ops to `ops_total_capacity` Ops.
340    ///
341    /// If either bytes/ops *size* or *refill_time* are **zero**, the limiter
342    /// is **disabled** for that respective token type.
343    ///
344    /// # Errors
345    ///
346    /// If the timerfd creation fails, an error is returned.
347    pub fn new(
348        bytes_total_capacity: u64,
349        bytes_one_time_burst: u64,
350        bytes_complete_refill_time_ms: u64,
351        ops_total_capacity: u64,
352        ops_one_time_burst: u64,
353        ops_complete_refill_time_ms: u64,
354    ) -> io::Result<Self> {
355        let bytes_token_bucket = TokenBucket::new(
356            bytes_total_capacity,
357            bytes_one_time_burst,
358            bytes_complete_refill_time_ms,
359        );
360
361        let ops_token_bucket = TokenBucket::new(
362            ops_total_capacity,
363            ops_one_time_burst,
364            ops_complete_refill_time_ms,
365        );
366
367        // We'll need a timer_fd, even if our current config effectively disables rate limiting,
368        // because `Self::update_buckets()` might re-enable it later, and we might be
369        // seccomp-blocked from creating the timer_fd at that time.
370        let timer_fd = TimerFd::new_custom(ClockId::Monotonic, true, true)?;
371
372        Ok(RateLimiter {
373            bandwidth: bytes_token_bucket,
374            ops: ops_token_bucket,
375            timer_fd,
376            timer_active: false,
377        })
378    }
379
380    // Arm the timer of the rate limiter with the provided `TimerState`.
381    fn activate_timer(&mut self, timer_state: TimerState) {
382        // Register the timer; don't care about its previous state
383        self.timer_fd.set_state(timer_state, SetTimeFlags::Default);
384        self.timer_active = true;
385    }
386
387    /// Attempts to consume tokens and returns whether that is possible.
388    ///
389    /// If rate limiting is disabled on provided `token_type`, this function will always succeed.
390    pub fn consume(&mut self, tokens: u64, token_type: TokenType) -> bool {
391        // If the timer is active, we can't consume tokens from any bucket and the function fails.
392        if self.timer_active {
393            return false;
394        }
395
396        // Identify the required token bucket.
397        let token_bucket = match token_type {
398            TokenType::Bytes => self.bandwidth.as_mut(),
399            TokenType::Ops => self.ops.as_mut(),
400        };
401        // Try to consume from the token bucket.
402        if let Some(bucket) = token_bucket {
403            let refill_time = bucket.refill_time_ms();
404            match bucket.reduce(tokens) {
405                // When we report budget is over, there will be no further calls here,
406                // register a timer to replenish the bucket and resume processing;
407                // make sure there is only one running timer for this limiter.
408                BucketReduction::Failure => {
409                    if !self.timer_active {
410                        self.activate_timer(TIMER_REFILL_STATE);
411                    }
412                    false
413                }
414                // The operation succeeded and further calls can be made.
415                BucketReduction::Success => true,
416                // The operation succeeded as the tokens have been consumed
417                // but the timer still needs to be armed.
418                BucketReduction::OverConsumption(ratio) => {
419                    // The operation "borrowed" a number of tokens `ratio` times
420                    // greater than the size of the bucket, and since it takes
421                    // `refill_time` milliseconds to fill an empty bucket, in
422                    // order to enforce the bandwidth limit we need to prevent
423                    // further calls to the rate limiter for
424                    // `ratio * refill_time` milliseconds.
425                    // The conversion should be safe because the ratio is positive.
426                    #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
427                    self.activate_timer(TimerState::Oneshot(Duration::from_millis(
428                        (ratio * refill_time as f64) as u64,
429                    )));
430                    true
431                }
432            }
433        } else {
434            // If bucket is not present rate limiting is disabled on token type,
435            // consume() will always succeed.
436            true
437        }
438    }
439
440    /// Adds tokens of `token_type` to their respective bucket.
441    ///
442    /// Can be used to *manually* add tokens to a bucket. Useful for reverting a
443    /// `consume()` if needed.
444    pub fn manual_replenish(&mut self, tokens: u64, token_type: TokenType) {
445        // Identify the required token bucket.
446        let token_bucket = match token_type {
447            TokenType::Bytes => self.bandwidth.as_mut(),
448            TokenType::Ops => self.ops.as_mut(),
449        };
450        // Add tokens to the token bucket.
451        if let Some(bucket) = token_bucket {
452            bucket.force_replenish(tokens);
453        }
454    }
455
456    /// Returns whether this rate limiter is blocked.
457    ///
458    /// The limiter 'blocks' when a `consume()` operation fails because there was not enough
459    /// budget for it.
460    /// An event will be generated on the exported FD when the limiter 'unblocks'.
461    pub fn is_blocked(&self) -> bool {
462        self.timer_active
463    }
464
465    /// This function needs to be called every time there is an event on the
466    /// FD provided by this object's `AsRawFd` trait implementation.
467    ///
468    /// # Errors
469    ///
470    /// If the rate limiter is disabled or is not blocked, an error is returned.
471    pub fn event_handler(&mut self) -> Result<(), RateLimiterError> {
472        match self.timer_fd.read() {
473            0 => Err(RateLimiterError::SpuriousRateLimiterEvent(
474                "Rate limiter event handler called without a present timer",
475            )),
476            _ => {
477                self.timer_active = false;
478                Ok(())
479            }
480        }
481    }
482
483    /// Updates the parameters of the token buckets associated with this RateLimiter.
484    // TODO: Please note that, right now, the buckets become full after being updated.
485    pub fn update_buckets(&mut self, bytes: BucketUpdate, ops: BucketUpdate) {
486        match bytes {
487            BucketUpdate::Disabled => self.bandwidth = None,
488            BucketUpdate::Update(tb) => self.bandwidth = Some(tb),
489            BucketUpdate::None => (),
490        };
491        match ops {
492            BucketUpdate::Disabled => self.ops = None,
493            BucketUpdate::Update(tb) => self.ops = Some(tb),
494            BucketUpdate::None => (),
495        };
496    }
497
498    /// Returns an immutable view of the inner bandwidth token bucket.
499    pub fn bandwidth(&self) -> Option<&TokenBucket> {
500        self.bandwidth.as_ref()
501    }
502
503    /// Returns an immutable view of the inner ops token bucket.
504    pub fn ops(&self) -> Option<&TokenBucket> {
505        self.ops.as_ref()
506    }
507}
508
impl AsRawFd for RateLimiter {
    /// Provides a FD which needs to be monitored for POLLIN events.
    ///
    /// This object's `event_handler()` method must be called on such events.
    ///
    /// The FD is the one of the internal timer, which `RateLimiter::new()` creates
    /// unconditionally, i.e. also when rate limiting is disabled on both token types.
    fn as_raw_fd(&self) -> RawFd {
        self.timer_fd.as_raw_fd()
    }
}
520
impl Default for RateLimiter {
    /// Default RateLimiter is a no-op limiter with infinite budget.
    ///
    /// # Panics
    ///
    /// Panics if `RateLimiter::new()` fails, which happens when the timerfd cannot be created.
    fn default() -> Self {
        // NOTE: all-zero parameters disable both buckets, but `RateLimiter::new()` still creates
        // its timerfd, so this `expect` can fire if that creation fails.
        RateLimiter::new(0, 0, 0, 0, 0, 0).expect("Failed to build default RateLimiter")
    }
}
528
// Kani proof harnesses for `gcd` and `TokenBucket`. Only compiled when verifying with Kani.
#[cfg(kani)]
#[allow(dead_code)] // Avoid warning when using stubs.
mod verification {
    use std::time::Instant;

    use super::*;

    // Replacement implementations that the harnesses below substitute via `#[kani::stub]`.
    mod stubs {
        use std::time::Instant;

        use crate::rate_limiter::TokenBucket;

        // On Unix, the Rust Standard Library defines Instants as
        //
        // struct Instant(struct inner::Instant {
        //     t: struct Timespec {
        //         tv_sec: i64,
        //         tv_nsec: struct Nanoseconds(u32),
        //     }
        // }
        //
        // This is not really repr-compatible with the below, as the structs (apart from
        // `Nanoseconds`) are repr(Rust), but currently this seems to work.
        #[repr(C)]
        struct InstantStub {
            tv_sec: i64,
            tv_nsec: u32,
        }

        // The last value returned by this stub, in nano seconds. We keep these variables separately
        // for Kani performance reasons (just counting nanos and then doing division/modulo
        // to get seconds/nanos is slow as those operations are very difficult for Kani's
        // underlying SAT solvers).
        static mut LAST_SECONDS: i64 = 0;
        static mut LAST_NANOS: u32 = 0;

        /// Stubs out `std::time::Instant::now` to return non-deterministic instants that are
        /// non-decreasing. The first value produced by this stub will always be 0. This is
        /// because generally harnesses only care about the delta between instants i1 and i2, which
        /// is arbitrary as long as at least one of i1, i2 is non-deterministic. Therefore,
        /// hardcoding one of the instants to be 0 brings a performance improvement. Should
        /// a harness lose generality due to the first Instant::now() call returning 0, add a
        /// dummy call to Instant::now() to the top of the harness to consume the 0 value. All
        /// subsequent calls will then result in non-deterministic values.
        fn instant_now() -> Instant {
            // Instants are non-decreasing.
            // See https://doc.rust-lang.org/std/time/struct.Instant.html.
            // NOTE(review): an earlier comment mentioned an upper bound on seconds to prevent
            // scenarios involving clock overflow, but no upper bound is applied below - confirm
            // whether one is needed.
            let next_seconds = kani::any_where(|n| *n >= unsafe { LAST_SECONDS });
            let next_nanos = kani::any_where(|n| *n < 1_000_000_000); // rustc intrinsic bound

            // Within the same second, the nanosecond part must not go backwards either.
            if next_seconds == unsafe { LAST_SECONDS } {
                kani::assume(next_nanos >= unsafe { LAST_NANOS });
            }

            // Return the value stored by the *previous* call (0 on the first call), then store
            // the freshly chosen value for the next call.
            let to_return = next_instant_now();

            unsafe {
                LAST_SECONDS = next_seconds;
                LAST_NANOS = next_nanos;
            }

            to_return
        }

        // Builds an `Instant` out of the currently stored `LAST_SECONDS` / `LAST_NANOS`.
        pub(super) fn next_instant_now() -> Instant {
            let stub = InstantStub {
                tv_sec: unsafe { LAST_SECONDS },
                tv_nsec: unsafe { LAST_NANOS },
            };

            // In normal rust code, this would not be safe, as the compiler can re-order the fields
            // However, kani will never run any transformations on the code, so this is safe. This
            // is because kani doesn't use rustc/llvm to compile down to bytecode, but instead
            // transpiles unoptimized rust MIR to goto-programs, which are then fed to CBMC.
            unsafe { std::mem::transmute(stub) }
        }

        /// Stubs out `TokenBucket::auto_replenish` by simply filling up the bucket by a
        /// non-deterministic amount that never takes the budget above the bucket size.
        fn token_bucket_auto_replenish(this: &mut TokenBucket) {
            this.budget += kani::any_where::<u64, _>(|&n| n <= this.size - this.budget);
        }
    }

    impl TokenBucket {
        /// Function checking that the general invariants of a TokenBucket are upheld
        fn is_valid(&self) -> bool {
            self.size != 0
                && self.refill_time != 0
                // The token budget can never exceed the bucket's size
                && self.budget <= self.size
                // The burst budget never exceeds its initial value
                && self.one_time_burst <= self.initial_one_time_burst
                // While burst budget is available, no tokens from the normal budget are consumed.
                && (self.one_time_burst == 0 || self.budget == self.size)
        }
    }

    // Produces an arbitrary `TokenBucket` that satisfies `is_valid()`.
    impl kani::Arbitrary for TokenBucket {
        fn any() -> TokenBucket {
            let bucket = TokenBucket::new(kani::any(), kani::any(), kani::any());
            kani::assume(bucket.is_some());
            let mut bucket = bucket.unwrap();

            // Adjust the budgets non-deterministically to simulate that the bucket has been "in
            // use" already
            bucket.budget = kani::any();
            bucket.one_time_burst = kani::any();

            kani::assume(bucket.is_valid());

            bucket
        }
    }

    // Checks that two consecutive calls to the stubbed `Instant::now` never go backwards.
    #[kani::proof]
    #[kani::stub(std::time::Instant::now, stubs::instant_now)]
    fn verify_instant_stub_non_decreasing() {
        let early = Instant::now();
        let late = Instant::now();
        assert!(early <= late);
    }

    // Euclid algorithm has runtime O(log(min(x,y))) -> kani::unwind(log(MAX)) should be enough.
    #[kani::proof_for_contract(gcd)]
    #[kani::unwind(64)]
    #[kani::solver(cadical)]
    fn gcd_contract_harness() {
        const MAX: u64 = 64;
        let x = kani::any_where(|&x| x < MAX);
        let y = kani::any_where(|&y| y < MAX);
        let gcd = super::gcd(x, y);
        // Most assertions are unnecessary as they are proved as part of the
        // contract. However for simplification the contract only enforces that
        // the result is *a* divisor, not necessarily the greatest one, so we
        // check that here manually: no common divisor `w` may exceed the result.
        if gcd != 0 {
            let w = kani::any_where(|&w| w > 0 && x % w == 0 && y % w == 0);
            assert!(gcd >= w);
        }
    }

    // Checks that `TokenBucket::new` only returns `None` for the documented invalid inputs and
    // otherwise returns a bucket that satisfies `is_valid()`.
    #[kani::proof]
    #[kani::stub(std::time::Instant::now, stubs::instant_now)]
    #[kani::stub_verified(gcd)]
    #[kani::solver(cadical)]
    fn verify_token_bucket_new() {
        let size = kani::any();
        let one_time_burst = kani::any();
        let complete_refill_time_ms = kani::any();

        // Checks if the `TokenBucket` is created with invalid inputs, the result is always `None`.
        match TokenBucket::new(size, one_time_burst, complete_refill_time_ms) {
            None => assert!(
                size == 0
                    || complete_refill_time_ms == 0
                    || complete_refill_time_ms > u64::MAX / NANOSEC_IN_ONE_MILLISEC
            ),
            Some(bucket) => assert!(bucket.is_valid()),
        }
    }

    // Checks that `auto_replenish` keeps a freshly created (small) bucket valid.
    #[kani::proof]
    #[kani::unwind(1)] // enough to unwind the recursion at `Timespec::sub_timespec`
    #[kani::stub(std::time::Instant::now, stubs::instant_now)]
    #[kani::stub_verified(gcd)]
    fn verify_token_bucket_auto_replenish() {
        const MAX_BUCKET_SIZE: u64 = 15;
        const MAX_REFILL_TIME: u64 = 15;

        // Create a non-deterministic `TokenBucket`. This internally calls `Instant::now()`, which
        // is stubbed to always return 0 on its first call. We can make this simplification
        // here, as `auto_replenish` only cares about the time delta between two consecutive
        // calls. This speeds up the verification significantly.
        let size = kani::any_where(|n| *n < MAX_BUCKET_SIZE && *n != 0);
        let complete_refill_time_ms = kani::any_where(|n| *n < MAX_REFILL_TIME && *n != 0);
        // `auto_replenish` doesn't use `one_time_burst`
        let mut bucket: TokenBucket = TokenBucket::new(size, 0, complete_refill_time_ms).unwrap();

        bucket.auto_replenish();

        assert!(bucket.is_valid());
    }

    // Checks the post-conditions of `reduce` on an arbitrary valid bucket, with
    // `auto_replenish` replaced by its stub.
    #[kani::proof]
    #[kani::stub(std::time::Instant::now, stubs::instant_now)]
    #[kani::stub(TokenBucket::auto_replenish, stubs::token_bucket_auto_replenish)]
    #[kani::stub_verified(gcd)]
    #[kani::solver(cadical)]
    fn verify_token_bucket_reduce() {
        let mut token_bucket: TokenBucket = kani::any();

        let old_token_bucket = token_bucket.clone();

        let tokens = kani::any();
        let result = token_bucket.reduce(tokens);

        assert!(token_bucket.is_valid());
        assert!(token_bucket.one_time_burst <= old_token_bucket.one_time_burst);

        // Initial burst always gets used up before budget. Read assertion as implication, i.e.,
        // `token_bucket.budget != old_token_bucket.budget => token_bucket.one_time_burst == 0`.
        assert!(token_bucket.budget == old_token_bucket.budget || token_bucket.one_time_burst == 0);

        // If reduction failed, bucket state should not change.
        if result == BucketReduction::Failure {
            // In case of a failure, no budget should have been consumed. However, since `reduce`
            // attempts to call `auto_replenish`, the budget could actually have
            // increased.
            assert!(token_bucket.budget >= old_token_bucket.budget);
            assert!(token_bucket.one_time_burst == old_token_bucket.one_time_burst);

            // Ensure that it is possible to trigger the BucketReduction::Failure case at all.
            // kani::cover makes verification fail if no possible execution path reaches
            // this line.
            kani::cover!();
        }
    }

    // Checks that `force_replenish` keeps the bucket valid and never lowers the budget or the
    // burst that were left after a preceding `reduce`.
    #[kani::proof]
    #[kani::stub(std::time::Instant::now, stubs::instant_now)]
    #[kani::stub_verified(gcd)]
    #[kani::stub(TokenBucket::auto_replenish, stubs::token_bucket_auto_replenish)]
    fn verify_token_bucket_force_replenish() {
        let mut token_bucket: TokenBucket = kani::any();

        token_bucket.reduce(kani::any());
        let reduced_budget = token_bucket.budget;
        let reduced_burst = token_bucket.one_time_burst;

        let to_replenish = kani::any();

        token_bucket.force_replenish(to_replenish);

        assert!(token_bucket.is_valid());
        assert!(token_bucket.budget >= reduced_budget);
        assert!(token_bucket.one_time_burst >= reduced_burst);
    }
}
769
770#[cfg(test)]
771pub(crate) mod tests {
772    use std::thread;
773    use std::time::Duration;
774
775    use super::*;
776
// Use a refill interval slightly longer than the real one. Tests that wait for a limiter
// refill in two stages (half of this interval each) can then rely on the limiter having been
// refilled by the end of the second wait. Without the extra margin there would be a race
// between the limiter being refilled and the test checking it.
782    const TEST_REFILL_TIMER_INTERVAL_MS: u64 = REFILL_TIMER_INTERVAL_MS + 10;
783
784    impl TokenBucket {
785        // Resets the token bucket: budget set to max capacity and last-updated set to now.
786        fn reset(&mut self) {
787            self.budget = self.size;
788            self.last_update = Instant::now();
789        }
790
791        fn get_last_update(&self) -> &Instant {
792            &self.last_update
793        }
794
795        fn get_processed_capacity(&self) -> u64 {
796            self.processed_capacity
797        }
798
799        fn get_processed_refill_time(&self) -> u64 {
800            self.processed_refill_time
801        }
802
803        // After a restore, we cannot be certain that the last_update field has the same value.
804        pub(crate) fn partial_eq(&self, other: &TokenBucket) -> bool {
805            (other.capacity() == self.capacity())
806                && (other.one_time_burst() == self.one_time_burst())
807                && (other.refill_time_ms() == self.refill_time_ms())
808                && (other.budget() == self.budget())
809        }
810    }
811
812    impl RateLimiter {
813        fn get_token_bucket(&self, token_type: TokenType) -> Option<&TokenBucket> {
814            match token_type {
815                TokenType::Bytes => self.bandwidth.as_ref(),
816                TokenType::Ops => self.ops.as_ref(),
817            }
818        }
819    }
820
// Drains a bucket and checks the budget after a series of timed `auto_replenish` calls.
#[test]
fn test_token_bucket_auto_replenish_one() {
    // These values will give 1 token every 100 milliseconds.
    const SIZE: u64 = 10;
    const TIME: u64 = 1000;

    // Each step is (additional sleep in milliseconds, budget expected after replenishing).
    const STEPS: [(u64, u64); 5] = [
        // 10 ms in total: not enough for a single token.
        (10, 0),
        // 20 ms in total: still nothing.
        (10, 0),
        // 100 ms in total: the first token shows up.
        (80, 1),
        // 500 ms in total: five tokens.
        (400, 5),
        // Fully replenished after one second; wait longer to make sure we do not overshoot.
        (1000, 10),
    ];

    let mut bucket = TokenBucket::new(SIZE, 0, TIME).unwrap();
    bucket.reduce(SIZE);
    assert_eq!(bucket.budget(), 0);

    for (sleep_ms, expected_budget) in STEPS {
        thread::sleep(Duration::from_millis(sleep_ms));
        bucket.auto_replenish();
        assert_eq!(bucket.budget(), expected_budget);
    }
}
856
// Calls `auto_replenish` in a tight loop for one full refill period and checks that the
// drained bucket ends up exactly full.
#[test]
fn test_token_bucket_auto_replenish_two() {
    const SIZE: u64 = 1000;
    const TIME: u64 = 1000;
    let refill_period = Duration::from_millis(TIME);

    let mut bucket = TokenBucket::new(SIZE, 0, TIME).unwrap();
    bucket.reduce(SIZE);
    assert_eq!(bucket.budget(), 0);

    // Keep replenishing until the whole refill period has elapsed.
    let started = Instant::now();
    loop {
        if started.elapsed() >= refill_period {
            break;
        }
        bucket.auto_replenish();
    }

    // One last replenish after the period is over.
    bucket.auto_replenish();
    assert_eq!(bucket.budget(), SIZE);
}
874
// Checks the initial state of a freshly created bucket and that invalid configurations are
// rejected.
#[test]
fn test_token_bucket_create() {
    let before = Instant::now();
    let bucket = TokenBucket::new(1000, 0, 1000).unwrap();
    assert_eq!(bucket.capacity(), 1000);
    assert_eq!(bucket.budget(), 1000);

    // `last_update` must fall between the instants taken around the constructor call.
    let created_at = *bucket.get_last_update();
    assert!(created_at >= before);
    let after = Instant::now();
    assert!(created_at <= after);

    assert_eq!(bucket.get_processed_capacity(), 1);
    assert_eq!(bucket.get_processed_refill_time(), 1_000_000);

    // Verify invalid bucket configurations result in `None`: a zero size, a zero refill
    // time, or both.
    for (size, one_time_burst, refill_time_ms) in [(0, 1234, 1000), (100, 1234, 0), (0, 1234, 0)] {
        assert!(TokenBucket::new(size, one_time_burst, refill_time_ms).is_none());
    }
}
892
// Checks the pre-processed capacity and refill time that `TokenBucket::new` derives from the
// configured size and refill time.
#[test]
fn test_token_bucket_preprocess() {
    let simple = TokenBucket::new(1000, 0, 1000).unwrap();
    assert_eq!(simple.get_processed_capacity(), 1);
    assert_eq!(simple.get_processed_refill_time(), NANOSEC_IN_ONE_MILLISEC);

    // The expected values below correspond to cancelling the common factor 7 * 11 * 1000
    // between the size and the refill time expressed in nanoseconds.
    let thousand = 1000;
    let size = 3 * 7 * 11 * 19 * thousand;
    let refill_time_ms = 7 * 11 * 13 * 17;
    let composite = TokenBucket::new(size, 0, refill_time_ms).unwrap();

    let expected_capacity = 3 * 19;
    let expected_refill_time = 13 * 17 * (NANOSEC_IN_ONE_MILLISEC / thousand);
    assert_eq!(composite.get_processed_capacity(), expected_capacity);
    assert_eq!(composite.get_processed_refill_time(), expected_refill_time);
}
907
// Exercises `reduce` for the success, failure and over-consumption outcomes, with and without
// a one-time burst, and finishes by checking the test-only `reset` helper.
#[test]
fn test_token_bucket_reduce() {
    // Bucket with a capacity of 1000 tokens and a refill time of 1000 ms, i.e. 1 token/ms.
    let capacity = 1000;
    let refill_ms = 1000;
    let mut plain = TokenBucket::new(capacity, 0, refill_ms).unwrap();

    assert_eq!(plain.reduce(123), BucketReduction::Success);
    assert_eq!(plain.budget(), capacity - 123);
    assert_eq!(plain.reduce(capacity), BucketReduction::Failure);

    // Same capacity and refill time, this time with a one-time burst of 1100 tokens.
    let mut bursty = TokenBucket::new(1000, 1100, 1000).unwrap();
    // Safely assuming the thread gets through the reductions below in less than 500 ms.
    assert_eq!(bursty.reduce(1000), BucketReduction::Success);
    assert_eq!(bursty.one_time_burst(), 100);
    assert_eq!(bursty.reduce(500), BucketReduction::Success);
    assert_eq!(bursty.one_time_burst(), 0);
    assert_eq!(bursty.reduce(500), BucketReduction::Success);
    assert_eq!(bursty.reduce(500), BucketReduction::Failure);

    // Half a refill period later, 500 tokens can be taken again.
    thread::sleep(Duration::from_millis(500));
    assert_eq!(bursty.reduce(500), BucketReduction::Success);

    // After a full refill period, asking for 2500 tokens over-consumes by 1.5x the size.
    thread::sleep(Duration::from_millis(1000));
    assert_eq!(bursty.reduce(2500), BucketReduction::OverConsumption(1.5));

    // `reset` refills the budget and refreshes `last_update`.
    let before = Instant::now();
    bursty.reset();
    assert_eq!(bursty.capacity(), 1000);
    assert_eq!(bursty.budget(), 1000);
    let reset_at = *bursty.get_last_update();
    assert!(reset_at >= before);
    let after = Instant::now();
    assert!(reset_at <= after);
}
942
// A default-constructed limiter is disabled: it never blocks, accepts any amount of tokens,
// and reports a spurious event when its handler is invoked.
#[test]
fn test_rate_limiter_default() {
    let mut limiter = RateLimiter::default();

    // The limiter should not be blocked.
    assert!(!limiter.is_blocked());
    // The limiter should be disabled, so consuming any amount should work.
    assert!(limiter.consume(u64::MAX, TokenType::Ops));
    assert!(limiter.consume(u64::MAX, TokenType::Bytes));

    // Calling the handler without there having been an event should error.
    limiter.event_handler().unwrap_err();
    let spurious = limiter.event_handler().err().unwrap();
    let rendered = format!("{:?}", spurious);
    assert_eq!(
        rendered,
        "SpuriousRateLimiterEvent(\"Rate limiter event handler called without a present \
         timer\")"
    );
}
960
// `RateLimiter::new` must forward its six arguments to the bandwidth and ops buckets, and both
// buckets must start with a full budget.
#[test]
fn test_rate_limiter_new() {
    let limiter = RateLimiter::new(1000, 1001, 1002, 1003, 1004, 1005).unwrap();

    // First three arguments configure the bandwidth bucket.
    let bandwidth = limiter.bandwidth.as_ref().unwrap();
    assert_eq!(bandwidth.capacity(), 1000);
    assert_eq!(bandwidth.one_time_burst(), 1001);
    assert_eq!(bandwidth.refill_time_ms(), 1002);
    assert_eq!(bandwidth.budget(), 1000);

    // Last three arguments configure the ops bucket.
    let operations = limiter.ops.as_ref().unwrap();
    assert_eq!(operations.capacity(), 1003);
    assert_eq!(operations.one_time_burst(), 1004);
    assert_eq!(operations.refill_time_ms(), 1005);
    assert_eq!(operations.budget(), 1003);
}
977
// Consumes tokens of each type, hands some of them back through `manual_replenish`, and checks
// the resulting budget.
#[test]
fn test_rate_limiter_manual_replenish() {
    // Rate limiter with a limit of 1000 bytes/s and 1000 ops/s.
    let mut limiter = RateLimiter::new(1000, 0, 1000, 1000, 0, 1000).unwrap();

    // Consume 123 bytes and give 23 back: 900 must remain.
    assert!(limiter.consume(123, TokenType::Bytes));
    limiter.manual_replenish(23, TokenType::Bytes);
    let bytes_budget = limiter
        .get_token_bucket(TokenType::Bytes)
        .unwrap()
        .budget();
    assert_eq!(bytes_budget, 900);

    // Consume 123 ops and give 23 back: 900 must remain.
    assert!(limiter.consume(123, TokenType::Ops));
    limiter.manual_replenish(23, TokenType::Ops);
    let ops_budget = limiter.get_token_bucket(TokenType::Ops).unwrap().budget();
    assert_eq!(ops_budget, 900);
}
998
// Bandwidth-only limiter: exhausting the byte budget blocks the limiter until the refill timer
// fires and `event_handler` is called.
#[test]
fn test_rate_limiter_bandwidth() {
    // Rate limiter with a limit of 1000 bytes/s; the ops bucket is not configured.
    let mut limiter = RateLimiter::new(1000, 0, 1000, 0, 0, 0).unwrap();
    let half_interval = Duration::from_millis(TEST_REFILL_TIMER_INTERVAL_MS / 2);

    // The limiter should not be blocked.
    assert!(!limiter.is_blocked());
    // The raw FD of this limiter should be valid. File descriptors are non-negative, and 0
    // is a legal value, so only negative values are rejected.
    assert!(limiter.as_raw_fd() >= 0);

    // The ops/s limiter should be disabled, so consuming any amount should work.
    assert!(limiter.consume(u64::MAX, TokenType::Ops));

    // Use up the full 1000 bytes.
    assert!(limiter.consume(1000, TokenType::Bytes));
    // Try and fail on another 100.
    assert!(!limiter.consume(100, TokenType::Bytes));
    // Since consume failed, the limiter should be blocked now.
    assert!(limiter.is_blocked());

    // Wait half the timer period: the limiter should still be blocked.
    thread::sleep(half_interval);
    assert!(limiter.is_blocked());

    // Wait the other half: the timer fd should have an event on it by now.
    thread::sleep(half_interval);
    limiter.event_handler().unwrap();

    // The limiter should now be unblocked and another 100 bytes should go through.
    assert!(!limiter.is_blocked());
    assert!(limiter.consume(100, TokenType::Bytes));
}
1031
// Ops-only limiter: exhausting the ops budget blocks the limiter until the refill timer fires
// and `event_handler` is called.
#[test]
fn test_rate_limiter_ops() {
    // Rate limiter with a limit of 1000 ops/s; the bandwidth bucket is not configured.
    let mut limiter = RateLimiter::new(0, 0, 0, 1000, 0, 1000).unwrap();
    let half_interval = Duration::from_millis(TEST_REFILL_TIMER_INTERVAL_MS / 2);

    // The limiter should not be blocked.
    assert!(!limiter.is_blocked());
    // The raw FD of this limiter should be valid. File descriptors are non-negative, and 0
    // is a legal value, so only negative values are rejected.
    assert!(limiter.as_raw_fd() >= 0);

    // The bytes/s limiter should be disabled, so consuming any amount should work.
    assert!(limiter.consume(u64::MAX, TokenType::Bytes));

    // Use up the full 1000 ops.
    assert!(limiter.consume(1000, TokenType::Ops));
    // Try and fail on another 100.
    assert!(!limiter.consume(100, TokenType::Ops));
    // Since consume failed, the limiter should be blocked now.
    assert!(limiter.is_blocked());

    // Wait half the timer period: the limiter should still be blocked.
    thread::sleep(half_interval);
    assert!(limiter.is_blocked());

    // Wait the other half: the timer fd should have an event on it by now.
    thread::sleep(half_interval);
    limiter.event_handler().unwrap();

    // The limiter should now be unblocked and another 100 ops should go through.
    assert!(!limiter.is_blocked());
    assert!(limiter.consume(100, TokenType::Ops));
}
1064
// Limiter with both buckets configured: exhausting both budgets blocks the limiter until the
// refill timer fires and `event_handler` is called.
#[test]
fn test_rate_limiter_full() {
    // Rate limiter with a limit of 1000 bytes/s and 1000 ops/s.
    let mut limiter = RateLimiter::new(1000, 0, 1000, 1000, 0, 1000).unwrap();
    let half_interval = Duration::from_millis(TEST_REFILL_TIMER_INTERVAL_MS / 2);

    // The limiter should not be blocked.
    assert!(!limiter.is_blocked());
    // The raw FD of this limiter should be valid. File descriptors are non-negative, and 0
    // is a legal value, so only negative values are rejected.
    assert!(limiter.as_raw_fd() >= 0);

    // Use up the full 1000 ops.
    assert!(limiter.consume(1000, TokenType::Ops));
    // Use up the full 1000 bytes.
    assert!(limiter.consume(1000, TokenType::Bytes));
    // Try and fail on another 100 ops.
    assert!(!limiter.consume(100, TokenType::Ops));
    // Try and fail on another 100 bytes.
    assert!(!limiter.consume(100, TokenType::Bytes));
    // Since consume failed, the limiter should be blocked now.
    assert!(limiter.is_blocked());

    // Wait half the timer period: the limiter should still be blocked.
    thread::sleep(half_interval);
    assert!(limiter.is_blocked());

    // Wait the other half: the timer fd should have an event on it by now.
    thread::sleep(half_interval);
    limiter.event_handler().unwrap();

    // The limiter should now be unblocked, and both token types should go through again.
    assert!(!limiter.is_blocked());
    assert!(limiter.consume(100, TokenType::Ops));
    assert!(limiter.consume(100, TokenType::Bytes));
}
1100
// Consuming more than the bucket size succeeds but leaves the limiter blocked for longer than
// the regular refill interval.
#[test]
fn test_rate_limiter_overconsumption() {
    // Scenario 1: consume 2.5x the bucket size. The bucket is full, so we are "borrowing"
    // 1.5x the bucket size in tokens.
    let mut limiter = RateLimiter::new(1000, 0, 1000, 1000, 0, 1000).unwrap();
    assert!(limiter.consume(2500, TokenType::Bytes));

    // Even after a whole second has passed, the rate limiter is still blocked.
    thread::sleep(Duration::from_millis(1000));
    limiter.event_handler().unwrap_err();
    assert!(limiter.is_blocked());

    // After 1.5x the replenish time has passed, the rate limiter is available again.
    thread::sleep(Duration::from_millis(500));
    limiter.event_handler().unwrap();
    assert!(!limiter.is_blocked());

    // Scenario 2: start over with a fresh limiter and consume 1.5x the bucket size. The
    // bucket is full, so we are "borrowing" 0.5x the bucket size in tokens, which should
    // arm the timer to 0.5x the replenish time, i.e. 500 ms.
    let mut limiter = RateLimiter::new(1000, 0, 1000, 1000, 0, 1000).unwrap();
    assert!(limiter.consume(1500, TokenType::Bytes));

    // After more than the minimum refill time, the rate limiter is still blocked.
    thread::sleep(Duration::from_millis(200));
    limiter.event_handler().unwrap_err();
    assert!(limiter.is_blocked());

    // Trying to consume some tokens should fail, as the timer is still active.
    assert!(!limiter.consume(100, TokenType::Bytes));
    limiter.event_handler().unwrap_err();
    assert!(limiter.is_blocked());

    // After another minimum refill time, the timer was not overwritten and the rate limiter
    // is still blocked from the borrowing performed earlier.
    thread::sleep(Duration::from_millis(100));
    limiter.event_handler().unwrap_err();
    assert!(limiter.is_blocked());
    assert!(!limiter.consume(100, TokenType::Bytes));

    // After waiting out the full duration, the rate limiter should be available again.
    thread::sleep(Duration::from_millis(200));
    limiter.event_handler().unwrap();
    assert!(!limiter.is_blocked());
    assert!(limiter.consume(100, TokenType::Bytes));
}
1157
// Covers the three `BucketUpdate` variants: `None` keeps the buckets, `Update` replaces them,
// and `Disabled` removes them.
#[test]
fn test_update_buckets() {
    let mut limiter = RateLimiter::new(1000, 2000, 1000, 10, 20, 1000).unwrap();
    let (initial_bw, initial_ops) = (limiter.bandwidth.clone(), limiter.ops.clone());

    // `BucketUpdate::None` must leave both buckets untouched.
    limiter.update_buckets(BucketUpdate::None, BucketUpdate::None);
    assert_eq!(limiter.bandwidth, initial_bw);
    assert_eq!(limiter.ops, initial_ops);

    // `BucketUpdate::Update` must install the given buckets.
    let replacement_bw = TokenBucket::new(123, 0, 57).unwrap();
    let replacement_ops = TokenBucket::new(321, 12346, 89).unwrap();
    limiter.update_buckets(
        BucketUpdate::Update(replacement_bw.clone()),
        BucketUpdate::Update(replacement_ops.clone()),
    );

    // We have to manually adjust the `last_update` field, because it changes when
    // `update_buckets()` constructs new buckets (and thus gets a different value for
    // `last_update`). Only then does it make sense to test the assertions below.
    limiter.bandwidth.as_mut().unwrap().last_update = replacement_bw.last_update;
    limiter.ops.as_mut().unwrap().last_update = replacement_ops.last_update;

    assert_eq!(limiter.bandwidth, Some(replacement_bw));
    assert_eq!(limiter.ops, Some(replacement_ops));

    // `BucketUpdate::Disabled` must drop both buckets.
    limiter.update_buckets(BucketUpdate::Disabled, BucketUpdate::Disabled);
    assert_eq!(limiter.bandwidth, None);
    assert_eq!(limiter.ops, None);
}
1189
// The `Debug` output of a limiter must consist of the `Debug` output of its two buckets.
#[test]
fn test_rate_limiter_debug() {
    let limiter = RateLimiter::new(1, 2, 3, 4, 5, 6).unwrap();

    let expected = format!(
        "RateLimiter {{ bandwidth: {:?}, ops: {:?} }}",
        limiter.bandwidth(),
        limiter.ops()
    );
    let actual = format!("{:?}", limiter);

    assert_eq!(actual, expected);
}
1202}