shadow_rs/network/relay/token_bucket.rs

1use shadow_shim_helper_rs::emulated_time::EmulatedTime;
2use shadow_shim_helper_rs::simulation_time::SimulationTime;
3
4use crate::core::worker::Worker;
5
6pub struct TokenBucket {
7    capacity: u64,
8    balance: u64,
9    refill_increment: u64,
10    refill_interval: SimulationTime,
11    last_refill: EmulatedTime,
12}
13
14impl TokenBucket {
15    /// Creates a new token bucket rate limiter with an initial balance set to
16    /// `capacity`. The capacity enables burstiness, while the long term rate is
17    /// defined by `refill_increment` tokens being periodically added to the
18    /// bucket every `refill_interval` duration. Returns None if any of the args
19    /// are non-positive.
20    pub fn new(
21        capacity: u64,
22        refill_increment: u64,
23        refill_interval: SimulationTime,
24    ) -> Option<TokenBucket> {
25        // Since we start at full capacity, starting with a last refill time of
26        // 0 is inconsequential.
27        TokenBucket::new_inner(
28            capacity,
29            refill_increment,
30            refill_interval,
31            EmulatedTime::SIMULATION_START,
32        )
33    }
34
35    /// Implements the functionality of `new()` allowing the caller to set the
36    /// last refill time. Useful for testing.
37    fn new_inner(
38        capacity: u64,
39        refill_increment: u64,
40        refill_interval: SimulationTime,
41        last_refill: EmulatedTime,
42    ) -> Option<TokenBucket> {
43        if capacity > 0 && refill_increment > 0 && !refill_interval.is_zero() {
44            log::trace!(
45                "Initializing token bucket with capacity {capacity}, will refill {refill_increment} tokens every {refill_interval:?}"
46            );
47            Some(TokenBucket {
48                capacity,
49                balance: capacity,
50                refill_increment,
51                refill_interval,
52                last_refill,
53            })
54        } else {
55            None
56        }
57    }
58
59    /// Remove `decrement` tokens from the bucket if and only if the bucket
60    /// contains at least `decrement` tokens. Returns the updated token balance
61    /// on success, or the duration until the next refill event after which we
62    /// would have enough tokens to allow the decrement to conform on error
63    /// (returned durations always align with this `TokenBucket`'s discrete
64    /// refill interval boundaries). Passing a 0 `decrement` always succeeds.
65    pub fn comforming_remove(&mut self, decrement: u64) -> Result<u64, SimulationTime> {
66        let now = Worker::current_time().unwrap();
67        self.conforming_remove_inner(decrement, &now)
68    }
69
70    /// Implements the functionality of `comforming_remove()` without calling into the
71    /// `Worker` module. Useful for testing.
72    fn conforming_remove_inner(
73        &mut self,
74        decrement: u64,
75        now: &EmulatedTime,
76    ) -> Result<u64, SimulationTime> {
77        let next_refill_span = self.lazy_refill(now);
78        self.balance = self
79            .balance
80            .checked_sub(decrement)
81            .ok_or_else(|| self.compute_conforming_duration(decrement, next_refill_span))?;
82        Ok(self.balance)
83    }
84
85    /// Computes the duration required to refill enough tokens such that our
86    /// balance can be decremented by the given `decrement`. Returned durations
87    /// always align with this `TokenBucket`'s discrete refill interval
88    /// boundaries, as configured by its refill interval. `next_refill_span` is
89    /// the duration until the next refill, which may be less than a full refill
90    /// interval.
91    fn compute_conforming_duration(
92        &self,
93        decrement: u64,
94        next_refill_span: SimulationTime,
95    ) -> SimulationTime {
96        let required_token_increment = decrement.saturating_sub(self.balance);
97
98        let num_required_refills = {
99            // Same as `required_token_increment.div_ceil(self.refill_increment);`
100            let num_refills = required_token_increment / self.refill_increment;
101            let remainder = required_token_increment % self.refill_increment;
102            if remainder > 0 {
103                num_refills + 1
104            } else {
105                num_refills
106            }
107        };
108
109        match num_required_refills {
110            0 => SimulationTime::ZERO,
111            1 => next_refill_span,
112            _ => next_refill_span.saturating_add(
113                self.refill_interval
114                    .saturating_mul(num_required_refills.checked_sub(1).unwrap()),
115            ),
116        }
117    }
118
119    /// Simulates a fixed refill schedule following the bucket's configured
120    /// refill interval. This function will lazily apply refills that may have
121    /// occurred in the past but were not applied yet because the token bucket
122    /// was not in use. No refills will occur if called multiple times within
123    /// the same refill interval. Returns the duration to the next refill event.
124    fn lazy_refill(&mut self, now: &EmulatedTime) -> SimulationTime {
125        let mut span = now.duration_since(&self.last_refill);
126
127        if span >= self.refill_interval {
128            // Apply refills for the scheduled refill events that have passed.
129            let num_refills = span
130                .as_nanos()
131                .checked_div(self.refill_interval.as_nanos())
132                .unwrap();
133            let num_tokens = self
134                .refill_increment
135                .saturating_mul(num_refills.try_into().unwrap());
136            debug_assert!(num_tokens > 0);
137
138            self.balance = self
139                .balance
140                .saturating_add(num_tokens)
141                .clamp(0, self.capacity);
142
143            // Update to the most recent refill event time.
144            let inc = self
145                .refill_interval
146                .saturating_mul(num_refills.try_into().unwrap());
147            self.last_refill = self.last_refill.saturating_add(inc);
148
149            span = now.duration_since(&self.last_refill);
150        }
151
152        debug_assert!(span < self.refill_interval);
153        self.refill_interval - span
154    }
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use crate::network::tests::mock_time_millis;

    #[test]
    fn test_new_invalid_args() {
        let now = mock_time_millis(1000);
        assert!(TokenBucket::new_inner(0, 1, SimulationTime::from_nanos(1), now).is_none());
        assert!(TokenBucket::new_inner(1, 0, SimulationTime::from_nanos(1), now).is_none());
        assert!(TokenBucket::new_inner(1, 1, SimulationTime::ZERO, now).is_none());
    }

    #[test]
    fn test_new_valid_args() {
        let now = mock_time_millis(1000);
        assert!(TokenBucket::new_inner(1, 1, SimulationTime::from_nanos(1), now).is_some());
        assert!(TokenBucket::new_inner(1, 1, SimulationTime::from_millis(1), now).is_some());
        assert!(TokenBucket::new_inner(1, 1, SimulationTime::from_secs(1), now).is_some());

        let tb = TokenBucket::new_inner(54321, 12345, SimulationTime::from_secs(1), now).unwrap();
        assert_eq!(tb.capacity, 54321);
        assert_eq!(tb.refill_increment, 12345);
        assert_eq!(tb.refill_interval, SimulationTime::from_secs(1));
    }

    #[test]
    fn test_refill_after_one_interval() {
        let interval = SimulationTime::from_millis(10);
        let capacity = 100;
        let increment = 10;
        let now = mock_time_millis(1000);

        let mut tb = TokenBucket::new_inner(capacity, increment, interval, now).unwrap();
        assert_eq!(tb.balance, capacity);

        // Remove all tokens
        assert!(tb.conforming_remove_inner(capacity, &now).is_ok());
        assert_eq!(tb.balance, 0);

        for i in 1..=(capacity / increment) {
            // One more interval of time passes
            let later = now + interval.saturating_mul(i);
            // Should cause an increment to the balance
            let result = tb.conforming_remove_inner(0, &later);
            assert!(result.is_ok());
            assert_eq!(result.unwrap(), tb.balance);
            assert_eq!(tb.balance, increment.saturating_mul(i));
        }
    }

    #[test]
    fn test_refill_after_multiple_intervals() {
        let now = mock_time_millis(1000);
        let mut tb = TokenBucket::new_inner(100, 10, SimulationTime::from_millis(10), now).unwrap();

        // Remove all tokens
        assert!(tb.conforming_remove_inner(100, &now).is_ok());
        assert_eq!(tb.balance, 0);

        // 5 Refill intervals have passed
        let later = now + SimulationTime::from_millis(50);

        let result = tb.conforming_remove_inner(0, &later);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), tb.balance);
        assert_eq!(tb.balance, 50);
    }

    #[test]
    fn test_no_refill_within_interval() {
        let now = mock_time_millis(1000);
        let mut tb = TokenBucket::new_inner(100, 10, SimulationTime::from_millis(10), now).unwrap();

        // Remove all tokens
        assert_eq!(tb.conforming_remove_inner(100, &now), Ok(0));

        // Less than one interval has passed, so nothing is refilled yet
        let later = now + SimulationTime::from_millis(9);
        assert_eq!(tb.conforming_remove_inner(0, &later), Ok(0));

        // Exactly one interval has passed
        let later = now + SimulationTime::from_millis(10);
        assert_eq!(tb.conforming_remove_inner(0, &later), Ok(10));

        // Refills follow a fixed schedule anchored at the creation time, so
        // 19 millis in is still inside the second interval
        let later = now + SimulationTime::from_millis(19);
        assert_eq!(tb.conforming_remove_inner(0, &later), Ok(10));

        let later = now + SimulationTime::from_millis(20);
        assert_eq!(tb.conforming_remove_inner(0, &later), Ok(20));
    }

    #[test]
    fn test_capacity_limit() {
        let now = mock_time_millis(1000);
        let mut tb = TokenBucket::new_inner(100, 10, SimulationTime::from_millis(10), now).unwrap();

        // Remove all tokens
        assert!(tb.conforming_remove_inner(100, &now).is_ok());
        assert_eq!(tb.balance, 0);

        // Far into the future
        let later = now + SimulationTime::from_secs(60);

        // Should not exceed capacity
        let result = tb.conforming_remove_inner(0, &later);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), tb.balance);
        assert_eq!(tb.balance, 100);
    }

    #[test]
    fn test_remove_error() {
        let now = mock_time_millis(1000);
        let mut tb =
            TokenBucket::new_inner(100, 10, SimulationTime::from_millis(125), now).unwrap();

        // Clear the bucket
        let result = tb.conforming_remove_inner(100, &now);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 0);

        // This many tokens are not available
        let result = tb.conforming_remove_inner(50, &now);
        assert!(result.is_err());

        // Refilling 10 tokens every 125 millis will require 5 refills
        let dur_until_conforming = SimulationTime::from_millis(125 * 5);
        assert_eq!(result.unwrap_err(), dur_until_conforming);

        // Moving time forward is still an error
        let inc = 10;
        let now = mock_time_millis(1000 + inc);
        let result = tb.conforming_remove_inner(50, &now);
        assert!(result.is_err());

        // We still need 5 refills, but we are 10 millis closer until it conforms
        let dur_until_conforming = SimulationTime::from_millis(125 * 5 - inc);
        assert_eq!(result.unwrap_err(), dur_until_conforming);
    }

    #[test]
    fn test_failed_remove_keeps_balance() {
        let now = mock_time_millis(1000);
        let mut tb =
            TokenBucket::new_inner(100, 10, SimulationTime::from_millis(100), now).unwrap();

        assert_eq!(tb.conforming_remove_inner(95, &now), Ok(5));

        // Need 16 more tokens, which takes 2 refills of 10 tokens
        assert_eq!(
            tb.conforming_remove_inner(21, &now),
            Err(SimulationTime::from_millis(200))
        );

        // Need exactly 10 more tokens, which takes a single refill
        assert_eq!(
            tb.conforming_remove_inner(15, &now),
            Err(SimulationTime::from_millis(100))
        );

        // The failed removals did not consume any tokens
        assert_eq!(tb.conforming_remove_inner(5, &now), Ok(0));
    }

    #[test]
    fn test_remove_conforms_after_returned_duration() {
        let now = mock_time_millis(1000);
        let mut tb =
            TokenBucket::new_inner(100, 10, SimulationTime::from_millis(125), now).unwrap();

        // Clear the bucket
        assert_eq!(tb.conforming_remove_inner(100, &now), Ok(0));

        let wait = tb.conforming_remove_inner(50, &now).unwrap_err();
        assert_eq!(wait, SimulationTime::from_millis(625));

        // Just before the returned duration elapses only 4 refills have
        // happened, so we must wait one more milli for the 5th
        let almost = now + SimulationTime::from_millis(624);
        assert_eq!(
            tb.conforming_remove_inner(50, &almost),
            Err(SimulationTime::from_millis(1))
        );

        // Once the returned duration has elapsed the removal conforms
        let later = now + wait;
        assert_eq!(tb.conforming_remove_inner(50, &later), Ok(0));
    }
}