1use shadow_shim_helper_rs::emulated_time::EmulatedTime;
2use shadow_shim_helper_rs::simulation_time::SimulationTime;
3
4use crate::core::worker::Worker;
5
6pub struct TokenBucket {
7 capacity: u64,
8 balance: u64,
9 refill_increment: u64,
10 refill_interval: SimulationTime,
11 last_refill: EmulatedTime,
12}
13
14impl TokenBucket {
15 pub fn new(
21 capacity: u64,
22 refill_increment: u64,
23 refill_interval: SimulationTime,
24 ) -> Option<TokenBucket> {
25 TokenBucket::new_inner(
28 capacity,
29 refill_increment,
30 refill_interval,
31 EmulatedTime::SIMULATION_START,
32 )
33 }
34
35 fn new_inner(
38 capacity: u64,
39 refill_increment: u64,
40 refill_interval: SimulationTime,
41 last_refill: EmulatedTime,
42 ) -> Option<TokenBucket> {
43 if capacity > 0 && refill_increment > 0 && !refill_interval.is_zero() {
44 log::trace!(
45 "Initializing token bucket with capacity {capacity}, will refill {refill_increment} tokens every {refill_interval:?}"
46 );
47 Some(TokenBucket {
48 capacity,
49 balance: capacity,
50 refill_increment,
51 refill_interval,
52 last_refill,
53 })
54 } else {
55 None
56 }
57 }
58
59 pub fn comforming_remove(&mut self, decrement: u64) -> Result<u64, SimulationTime> {
66 let now = Worker::current_time().unwrap();
67 self.conforming_remove_inner(decrement, &now)
68 }
69
70 fn conforming_remove_inner(
73 &mut self,
74 decrement: u64,
75 now: &EmulatedTime,
76 ) -> Result<u64, SimulationTime> {
77 let next_refill_span = self.lazy_refill(now);
78 self.balance = self
79 .balance
80 .checked_sub(decrement)
81 .ok_or_else(|| self.compute_conforming_duration(decrement, next_refill_span))?;
82 Ok(self.balance)
83 }
84
85 fn compute_conforming_duration(
92 &self,
93 decrement: u64,
94 next_refill_span: SimulationTime,
95 ) -> SimulationTime {
96 let required_token_increment = decrement.saturating_sub(self.balance);
97
98 let num_required_refills = {
99 let num_refills = required_token_increment / self.refill_increment;
101 let remainder = required_token_increment % self.refill_increment;
102 if remainder > 0 {
103 num_refills + 1
104 } else {
105 num_refills
106 }
107 };
108
109 match num_required_refills {
110 0 => SimulationTime::ZERO,
111 1 => next_refill_span,
112 _ => next_refill_span.saturating_add(
113 self.refill_interval
114 .saturating_mul(num_required_refills.checked_sub(1).unwrap()),
115 ),
116 }
117 }
118
119 fn lazy_refill(&mut self, now: &EmulatedTime) -> SimulationTime {
125 let mut span = now.duration_since(&self.last_refill);
126
127 if span >= self.refill_interval {
128 let num_refills = span
130 .as_nanos()
131 .checked_div(self.refill_interval.as_nanos())
132 .unwrap();
133 let num_tokens = self
134 .refill_increment
135 .saturating_mul(num_refills.try_into().unwrap());
136 debug_assert!(num_tokens > 0);
137
138 self.balance = self
139 .balance
140 .saturating_add(num_tokens)
141 .clamp(0, self.capacity);
142
143 let inc = self
145 .refill_interval
146 .saturating_mul(num_refills.try_into().unwrap());
147 self.last_refill = self.last_refill.saturating_add(inc);
148
149 span = now.duration_since(&self.last_refill);
150 }
151
152 debug_assert!(span < self.refill_interval);
153 self.refill_interval - span
154 }
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use crate::network::tests::mock_time_millis;

    /// Builds a bucket with capacity 100 that refills 10 tokens every 10 ms,
    /// anchored at `start`, and immediately removes all of its tokens.
    fn drained_bucket(start: EmulatedTime) -> TokenBucket {
        let mut bucket =
            TokenBucket::new_inner(100, 10, SimulationTime::from_millis(10), start).unwrap();

        assert!(bucket.conforming_remove_inner(100, &start).is_ok());
        assert_eq!(bucket.balance, 0);

        bucket
    }

    #[test]
    fn test_new_invalid_args() {
        let start = mock_time_millis(1000);
        let one_nano = SimulationTime::from_nanos(1);

        // Each case sets exactly one of the three parameters to zero.
        let invalid_args = [
            (0, 1, one_nano),
            (1, 0, one_nano),
            (1, 1, SimulationTime::ZERO),
        ];

        for (capacity, increment, interval) in invalid_args {
            assert!(TokenBucket::new_inner(capacity, increment, interval, start).is_none());
        }
    }

    #[test]
    fn test_new_valid_args() {
        let start = mock_time_millis(1000);

        let valid_intervals = [
            SimulationTime::from_nanos(1),
            SimulationTime::from_millis(1),
            SimulationTime::from_secs(1),
        ];
        for interval in valid_intervals {
            assert!(TokenBucket::new_inner(1, 1, interval, start).is_some());
        }

        // The constructor stores its arguments unchanged.
        let one_sec = SimulationTime::from_secs(1);
        let bucket = TokenBucket::new_inner(54321, 12345, one_sec, start).unwrap();
        assert_eq!(bucket.capacity, 54321);
        assert_eq!(bucket.refill_increment, 12345);
        assert_eq!(bucket.refill_interval, SimulationTime::from_secs(1));
    }

    #[test]
    fn test_refill_after_one_interval() {
        let start = mock_time_millis(1000);
        let (capacity, increment) = (100, 10);
        let interval = SimulationTime::from_millis(10);

        let mut bucket = TokenBucket::new_inner(capacity, increment, interval, start).unwrap();
        assert_eq!(bucket.balance, capacity);

        // Drain the bucket completely.
        assert!(bucket.conforming_remove_inner(capacity, &start).is_ok());
        assert_eq!(bucket.balance, 0);

        // Step forward one interval at a time; each step adds one increment
        // until the bucket is full again.
        let refills_to_full = capacity / increment;
        for elapsed_intervals in 1..=refills_to_full {
            let later = start + interval.saturating_mul(elapsed_intervals);

            // Removing zero tokens only triggers the refill.
            let reported = bucket.conforming_remove_inner(0, &later).unwrap();
            assert_eq!(reported, bucket.balance);
            assert_eq!(bucket.balance, increment.saturating_mul(elapsed_intervals));
        }
    }

    #[test]
    fn test_refill_after_multiple_intervals() {
        let start = mock_time_millis(1000);
        let mut bucket = drained_bucket(start);

        // Five intervals pass at once, so five increments are applied.
        let later = start + SimulationTime::from_millis(50);

        let reported = bucket.conforming_remove_inner(0, &later).unwrap();
        assert_eq!(reported, bucket.balance);
        assert_eq!(bucket.balance, 50);
    }

    #[test]
    fn test_capacity_limit() {
        let start = mock_time_millis(1000);
        let mut bucket = drained_bucket(start);

        // Far more intervals pass than are needed to fill the bucket.
        let later = start + SimulationTime::from_secs(60);

        let reported = bucket.conforming_remove_inner(0, &later).unwrap();
        assert_eq!(reported, bucket.balance);
        // The balance is capped at the capacity.
        assert_eq!(bucket.balance, 100);
    }

    #[test]
    fn test_remove_error() {
        let start = mock_time_millis(1000);
        let mut bucket =
            TokenBucket::new_inner(100, 10, SimulationTime::from_millis(125), start).unwrap();

        // Removing the full capacity succeeds and leaves the bucket empty.
        assert_eq!(bucket.conforming_remove_inner(100, &start), Ok(0));

        // 50 tokens need five refills of 10 tokens, 125 ms apart.
        assert_eq!(
            bucket.conforming_remove_inner(50, &start),
            Err(SimulationTime::from_millis(125 * 5))
        );

        // After some time has passed, the wait shrinks by that amount.
        let elapsed_millis = 10;
        let later = mock_time_millis(1000 + elapsed_millis);
        assert_eq!(
            bucket.conforming_remove_inner(50, &later),
            Err(SimulationTime::from_millis(125 * 5 - elapsed_millis))
        );
    }
}
274}