shadow_rs/network/relay/
token_bucket.rs

1use shadow_shim_helper_rs::emulated_time::EmulatedTime;
2use shadow_shim_helper_rs::simulation_time::SimulationTime;
3
4use crate::core::worker::Worker;
5
6pub struct TokenBucket {
7    capacity: u64,
8    balance: u64,
9    refill_increment: u64,
10    refill_interval: SimulationTime,
11    last_refill: EmulatedTime,
12}
13
14impl TokenBucket {
15    /// Creates a new token bucket rate limiter with an initial balance set to
16    /// `capacity`. The capacity enables burstiness, while the long term rate is
17    /// defined by `refill_increment` tokens being periodically added to the
18    /// bucket every `refill_interval` duration. Returns None if any of the args
19    /// are non-positive.
20    pub fn new(
21        capacity: u64,
22        refill_increment: u64,
23        refill_interval: SimulationTime,
24    ) -> Option<TokenBucket> {
25        // Since we start at full capacity, starting with a last refill time of
26        // 0 is inconsequential.
27        TokenBucket::new_inner(
28            capacity,
29            refill_increment,
30            refill_interval,
31            EmulatedTime::SIMULATION_START,
32        )
33    }
34
35    /// Implements the functionality of `new()` allowing the caller to set the
36    /// last refill time. Useful for testing.
37    fn new_inner(
38        capacity: u64,
39        refill_increment: u64,
40        refill_interval: SimulationTime,
41        last_refill: EmulatedTime,
42    ) -> Option<TokenBucket> {
43        if capacity > 0 && refill_increment > 0 && !refill_interval.is_zero() {
44            log::trace!(
45                "Initializing token bucket with capacity {}, will refill {} tokens every {:?}",
46                capacity,
47                refill_increment,
48                refill_interval
49            );
50            Some(TokenBucket {
51                capacity,
52                balance: capacity,
53                refill_increment,
54                refill_interval,
55                last_refill,
56            })
57        } else {
58            None
59        }
60    }
61
62    /// Remove `decrement` tokens from the bucket if and only if the bucket
63    /// contains at least `decrement` tokens. Returns the updated token balance
64    /// on success, or the duration until the next refill event after which we
65    /// would have enough tokens to allow the decrement to conform on error
66    /// (returned durations always align with this `TokenBucket`'s discrete
67    /// refill interval boundaries). Passing a 0 `decrement` always succeeds.
68    pub fn comforming_remove(&mut self, decrement: u64) -> Result<u64, SimulationTime> {
69        let now = Worker::current_time().unwrap();
70        self.conforming_remove_inner(decrement, &now)
71    }
72
73    /// Implements the functionality of `comforming_remove()` without calling into the
74    /// `Worker` module. Useful for testing.
75    fn conforming_remove_inner(
76        &mut self,
77        decrement: u64,
78        now: &EmulatedTime,
79    ) -> Result<u64, SimulationTime> {
80        let next_refill_span = self.lazy_refill(now);
81        self.balance = self
82            .balance
83            .checked_sub(decrement)
84            .ok_or_else(|| self.compute_conforming_duration(decrement, next_refill_span))?;
85        Ok(self.balance)
86    }
87
88    /// Computes the duration required to refill enough tokens such that our
89    /// balance can be decremented by the given `decrement`. Returned durations
90    /// always align with this `TokenBucket`'s discrete refill interval
91    /// boundaries, as configured by its refill interval. `next_refill_span` is
92    /// the duration until the next refill, which may be less than a full refill
93    /// interval.
94    fn compute_conforming_duration(
95        &self,
96        decrement: u64,
97        next_refill_span: SimulationTime,
98    ) -> SimulationTime {
99        let required_token_increment = decrement.saturating_sub(self.balance);
100
101        let num_required_refills = {
102            // Same as `required_token_increment.div_ceil(self.refill_increment);`
103            let num_refills = required_token_increment / self.refill_increment;
104            let remainder = required_token_increment % self.refill_increment;
105            if remainder > 0 {
106                num_refills + 1
107            } else {
108                num_refills
109            }
110        };
111
112        match num_required_refills {
113            0 => SimulationTime::ZERO,
114            1 => next_refill_span,
115            _ => next_refill_span.saturating_add(
116                self.refill_interval
117                    .saturating_mul(num_required_refills.checked_sub(1).unwrap()),
118            ),
119        }
120    }
121
122    /// Simulates a fixed refill schedule following the bucket's configured
123    /// refill interval. This function will lazily apply refills that may have
124    /// occurred in the past but were not applied yet because the token bucket
125    /// was not in use. No refills will occur if called multiple times within
126    /// the same refill interval. Returns the duration to the next refill event.
127    fn lazy_refill(&mut self, now: &EmulatedTime) -> SimulationTime {
128        let mut span = now.duration_since(&self.last_refill);
129
130        if span >= self.refill_interval {
131            // Apply refills for the scheduled refill events that have passed.
132            let num_refills = span
133                .as_nanos()
134                .checked_div(self.refill_interval.as_nanos())
135                .unwrap();
136            let num_tokens = self
137                .refill_increment
138                .saturating_mul(num_refills.try_into().unwrap());
139            debug_assert!(num_tokens > 0);
140
141            self.balance = self
142                .balance
143                .saturating_add(num_tokens)
144                .clamp(0, self.capacity);
145
146            // Update to the most recent refill event time.
147            let inc = self
148                .refill_interval
149                .saturating_mul(num_refills.try_into().unwrap());
150            self.last_refill = self.last_refill.saturating_add(inc);
151
152            span = now.duration_since(&self.last_refill);
153        }
154
155        debug_assert!(span < self.refill_interval);
156        self.refill_interval - span
157    }
158}
159
#[cfg(test)]
mod tests {
    use super::*;
    use crate::network::tests::mock_time_millis;

    #[test]
    fn test_new_invalid_args() {
        let now = mock_time_millis(1000);
        assert!(TokenBucket::new_inner(0, 1, SimulationTime::from_nanos(1), now).is_none());
        assert!(TokenBucket::new_inner(1, 0, SimulationTime::from_nanos(1), now).is_none());
        assert!(TokenBucket::new_inner(1, 1, SimulationTime::ZERO, now).is_none());
    }

    #[test]
    fn test_new_valid_args() {
        let now = mock_time_millis(1000);
        assert!(TokenBucket::new_inner(1, 1, SimulationTime::from_nanos(1), now).is_some());
        assert!(TokenBucket::new_inner(1, 1, SimulationTime::from_millis(1), now).is_some());
        assert!(TokenBucket::new_inner(1, 1, SimulationTime::from_secs(1), now).is_some());

        let tb = TokenBucket::new_inner(54321, 12345, SimulationTime::from_secs(1), now).unwrap();
        assert_eq!(tb.capacity, 54321);
        assert_eq!(tb.refill_increment, 12345);
        assert_eq!(tb.refill_interval, SimulationTime::from_secs(1));
    }

    #[test]
    fn test_refill_after_one_interval() {
        let interval = SimulationTime::from_millis(10);
        let capacity = 100;
        let increment = 10;
        let now = mock_time_millis(1000);

        let mut tb = TokenBucket::new_inner(capacity, increment, interval, now).unwrap();
        assert_eq!(tb.balance, capacity);

        // Remove all tokens
        assert!(tb.conforming_remove_inner(capacity, &now).is_ok());
        assert_eq!(tb.balance, 0);

        for i in 1..=(capacity / increment) {
            // One more interval of time passes
            let later = now + interval.saturating_mul(i);
            // Should cause an increment to the balance
            let result = tb.conforming_remove_inner(0, &later);
            assert!(result.is_ok());
            assert_eq!(result.unwrap(), tb.balance);
            assert_eq!(tb.balance, increment.saturating_mul(i));
        }
    }

    #[test]
    fn test_refill_after_multiple_intervals() {
        let now = mock_time_millis(1000);
        let mut tb = TokenBucket::new_inner(100, 10, SimulationTime::from_millis(10), now).unwrap();

        // Remove all tokens
        assert!(tb.conforming_remove_inner(100, &now).is_ok());
        assert_eq!(tb.balance, 0);

        // 5 Refill intervals have passed
        let later = now + SimulationTime::from_millis(50);

        let result = tb.conforming_remove_inner(0, &later);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), tb.balance);
        assert_eq!(tb.balance, 50);
    }

    #[test]
    fn test_capacity_limit() {
        let now = mock_time_millis(1000);
        let mut tb = TokenBucket::new_inner(100, 10, SimulationTime::from_millis(10), now).unwrap();

        // Remove all tokens
        assert!(tb.conforming_remove_inner(100, &now).is_ok());
        assert_eq!(tb.balance, 0);

        // Far into the future
        let later = now + SimulationTime::from_secs(60);

        // Should not exceed capacity
        let result = tb.conforming_remove_inner(0, &later);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), tb.balance);
        assert_eq!(tb.balance, 100);
    }

    #[test]
    fn test_remove_error() {
        let now = mock_time_millis(1000);
        let mut tb =
            TokenBucket::new_inner(100, 10, SimulationTime::from_millis(125), now).unwrap();

        // Clear the bucket
        let result = tb.conforming_remove_inner(100, &now);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 0);

        // This many tokens are not available
        let result = tb.conforming_remove_inner(50, &now);
        assert!(result.is_err());

        // Refilling 10 tokens every 125 millis will require 5 refills
        let dur_until_conforming = SimulationTime::from_millis(125 * 5);
        assert_eq!(result.unwrap_err(), dur_until_conforming);

        // Moving time forward is still an error
        let inc = 10;
        let now = mock_time_millis(1000 + inc);
        let result = tb.conforming_remove_inner(50, &now);
        assert!(result.is_err());

        // We still need 5 refills, but we are 10 millis closer until it conforms
        let dur_until_conforming = SimulationTime::from_millis(125 * 5 - inc);
        assert_eq!(result.unwrap_err(), dur_until_conforming);
    }

    #[test]
    fn test_zero_decrement_always_conforms() {
        let now = mock_time_millis(1000);
        let mut tb = TokenBucket::new_inner(100, 10, SimulationTime::from_millis(10), now).unwrap();

        // Empty the bucket, then a zero decrement must still succeed.
        assert_eq!(tb.conforming_remove_inner(100, &now), Ok(0));
        assert_eq!(tb.conforming_remove_inner(0, &now), Ok(0));
    }

    #[test]
    fn test_refill_schedule_stays_aligned() {
        let interval = SimulationTime::from_millis(10);
        let now = mock_time_millis(1000);
        let mut tb = TokenBucket::new_inner(100, 10, interval, now).unwrap();

        // Remove all tokens
        assert_eq!(tb.conforming_remove_inner(100, &now), Ok(0));

        // At 1019ms, exactly one refill event (at 1010ms) has passed.
        let t1 = mock_time_millis(1019);
        assert_eq!(tb.conforming_remove_inner(0, &t1), Ok(10));

        // At 1021ms, a second refill event (at 1020ms) has passed. Had the
        // schedule been re-anchored to 1019ms, no refill would have occurred.
        let t2 = mock_time_millis(1021);
        assert_eq!(tb.conforming_remove_inner(0, &t2), Ok(20));

        // The next refill is at 1030ms (9ms away), and a single refill is
        // enough to cover the 10 missing tokens.
        assert_eq!(
            tb.conforming_remove_inner(30, &t2),
            Err(SimulationTime::from_millis(9))
        );
    }
}