1use shadow_shim_helper_rs::emulated_time::EmulatedTime;
2use shadow_shim_helper_rs::simulation_time::SimulationTime;
3
4use crate::core::worker::Worker;
5
6pub struct TokenBucket {
7 capacity: u64,
8 balance: u64,
9 refill_increment: u64,
10 refill_interval: SimulationTime,
11 last_refill: EmulatedTime,
12}
13
14impl TokenBucket {
15 pub fn new(
21 capacity: u64,
22 refill_increment: u64,
23 refill_interval: SimulationTime,
24 ) -> Option<TokenBucket> {
25 TokenBucket::new_inner(
28 capacity,
29 refill_increment,
30 refill_interval,
31 EmulatedTime::SIMULATION_START,
32 )
33 }
34
35 fn new_inner(
38 capacity: u64,
39 refill_increment: u64,
40 refill_interval: SimulationTime,
41 last_refill: EmulatedTime,
42 ) -> Option<TokenBucket> {
43 if capacity > 0 && refill_increment > 0 && !refill_interval.is_zero() {
44 log::trace!(
45 "Initializing token bucket with capacity {}, will refill {} tokens every {:?}",
46 capacity,
47 refill_increment,
48 refill_interval
49 );
50 Some(TokenBucket {
51 capacity,
52 balance: capacity,
53 refill_increment,
54 refill_interval,
55 last_refill,
56 })
57 } else {
58 None
59 }
60 }
61
62 pub fn comforming_remove(&mut self, decrement: u64) -> Result<u64, SimulationTime> {
69 let now = Worker::current_time().unwrap();
70 self.conforming_remove_inner(decrement, &now)
71 }
72
73 fn conforming_remove_inner(
76 &mut self,
77 decrement: u64,
78 now: &EmulatedTime,
79 ) -> Result<u64, SimulationTime> {
80 let next_refill_span = self.lazy_refill(now);
81 self.balance = self
82 .balance
83 .checked_sub(decrement)
84 .ok_or_else(|| self.compute_conforming_duration(decrement, next_refill_span))?;
85 Ok(self.balance)
86 }
87
88 fn compute_conforming_duration(
95 &self,
96 decrement: u64,
97 next_refill_span: SimulationTime,
98 ) -> SimulationTime {
99 let required_token_increment = decrement.saturating_sub(self.balance);
100
101 let num_required_refills = {
102 let num_refills = required_token_increment / self.refill_increment;
104 let remainder = required_token_increment % self.refill_increment;
105 if remainder > 0 {
106 num_refills + 1
107 } else {
108 num_refills
109 }
110 };
111
112 match num_required_refills {
113 0 => SimulationTime::ZERO,
114 1 => next_refill_span,
115 _ => next_refill_span.saturating_add(
116 self.refill_interval
117 .saturating_mul(num_required_refills.checked_sub(1).unwrap()),
118 ),
119 }
120 }
121
122 fn lazy_refill(&mut self, now: &EmulatedTime) -> SimulationTime {
128 let mut span = now.duration_since(&self.last_refill);
129
130 if span >= self.refill_interval {
131 let num_refills = span
133 .as_nanos()
134 .checked_div(self.refill_interval.as_nanos())
135 .unwrap();
136 let num_tokens = self
137 .refill_increment
138 .saturating_mul(num_refills.try_into().unwrap());
139 debug_assert!(num_tokens > 0);
140
141 self.balance = self
142 .balance
143 .saturating_add(num_tokens)
144 .clamp(0, self.capacity);
145
146 let inc = self
148 .refill_interval
149 .saturating_mul(num_refills.try_into().unwrap());
150 self.last_refill = self.last_refill.saturating_add(inc);
151
152 span = now.duration_since(&self.last_refill);
153 }
154
155 debug_assert!(span < self.refill_interval);
156 self.refill_interval - span
157 }
158}
159
#[cfg(test)]
mod tests {
    use super::*;
    use crate::network::tests::mock_time_millis;

    /// A zero capacity, zero increment, or zero interval must be rejected.
    #[test]
    fn test_new_invalid_args() {
        let start = mock_time_millis(1000);
        let one_ns = SimulationTime::from_nanos(1);
        let cases = [(0, 1, one_ns), (1, 0, one_ns), (1, 1, SimulationTime::ZERO)];
        for (cap, inc, interval) in cases {
            assert!(TokenBucket::new_inner(cap, inc, interval, start).is_none());
        }
    }

    /// Nonzero arguments produce a bucket that stores them unchanged.
    #[test]
    fn test_new_valid_args() {
        let start = mock_time_millis(1000);
        let intervals = [
            SimulationTime::from_nanos(1),
            SimulationTime::from_millis(1),
            SimulationTime::from_secs(1),
        ];
        for interval in intervals {
            assert!(TokenBucket::new_inner(1, 1, interval, start).is_some());
        }

        let bucket = TokenBucket::new_inner(54321, 12345, SimulationTime::from_secs(1), start)
            .expect("nonzero arguments should produce a bucket");
        assert_eq!(bucket.capacity, 54321);
        assert_eq!(bucket.refill_increment, 12345);
        assert_eq!(bucket.refill_interval, SimulationTime::from_secs(1));
    }

    /// After draining, each elapsed interval credits exactly one increment.
    #[test]
    fn test_refill_after_one_interval() {
        const CAPACITY: u64 = 100;
        const INCREMENT: u64 = 10;
        let interval = SimulationTime::from_millis(10);
        let start = mock_time_millis(1000);

        let mut bucket = TokenBucket::new_inner(CAPACITY, INCREMENT, interval, start).unwrap();
        assert_eq!(bucket.balance, CAPACITY);

        assert!(bucket.conforming_remove_inner(CAPACITY, &start).is_ok());
        assert_eq!(bucket.balance, 0);

        for step in 1..=(CAPACITY / INCREMENT) {
            let at = start + interval.saturating_mul(step);
            let remaining = bucket
                .conforming_remove_inner(0, &at)
                .expect("removing zero tokens always conforms");
            assert_eq!(remaining, bucket.balance);
            assert_eq!(bucket.balance, INCREMENT.saturating_mul(step));
        }
    }

    /// Several intervals elapsing at once are all credited in a single refill.
    #[test]
    fn test_refill_after_multiple_intervals() {
        let start = mock_time_millis(1000);
        let mut bucket =
            TokenBucket::new_inner(100, 10, SimulationTime::from_millis(10), start).unwrap();

        assert!(bucket.conforming_remove_inner(100, &start).is_ok());
        assert_eq!(bucket.balance, 0);

        // Five 10 ms intervals pass, crediting five increments of 10 tokens.
        let at = start + SimulationTime::from_millis(50);
        let remaining = bucket
            .conforming_remove_inner(0, &at)
            .expect("removing zero tokens always conforms");
        assert_eq!(remaining, bucket.balance);
        assert_eq!(bucket.balance, 50);
    }

    /// No matter how much time passes, the balance never exceeds the capacity.
    #[test]
    fn test_capacity_limit() {
        let start = mock_time_millis(1000);
        let mut bucket =
            TokenBucket::new_inner(100, 10, SimulationTime::from_millis(10), start).unwrap();

        assert!(bucket.conforming_remove_inner(100, &start).is_ok());
        assert_eq!(bucket.balance, 0);

        let at = start + SimulationTime::from_secs(60);
        let remaining = bucket
            .conforming_remove_inner(0, &at)
            .expect("removing zero tokens always conforms");
        assert_eq!(remaining, bucket.balance);
        assert_eq!(bucket.balance, 100);
    }

    /// A non-conforming removal reports the wait until enough tokens accumulate.
    #[test]
    fn test_remove_error() {
        let start = mock_time_millis(1000);
        let mut bucket =
            TokenBucket::new_inner(100, 10, SimulationTime::from_millis(125), start).unwrap();

        assert_eq!(bucket.conforming_remove_inner(100, &start), Ok(0));

        // With an empty bucket, 50 tokens need five refills of 10, each 125 ms apart.
        let wait = bucket.conforming_remove_inner(50, &start).unwrap_err();
        assert_eq!(wait, SimulationTime::from_millis(125 * 5));

        // 10 ms later nothing has been refilled yet, but the wait is 10 ms shorter.
        let elapsed_ms = 10;
        let later = mock_time_millis(1000 + elapsed_ms);
        let wait = bucket.conforming_remove_inner(50, &later).unwrap_err();
        assert_eq!(wait, SimulationTime::from_millis(125 * 5 - elapsed_ms));
    }
}