use shadow_shim_helper_rs::emulated_time::EmulatedTime;
use shadow_shim_helper_rs::simulation_time::SimulationTime;
use crate::core::worker::Worker;
/// A token bucket used to rate-limit removals of a resource.
///
/// The bucket holds at most `capacity` tokens and starts out full. Every
/// `refill_interval` of simulated time, `refill_increment` tokens are added back
/// (never exceeding `capacity`). Refills are applied lazily, whenever the bucket
/// is next accessed, rather than by a timer.
pub struct TokenBucket {
    // Configuration: fixed when the bucket is constructed.
    /// Maximum number of tokens the bucket can hold.
    capacity: u64,
    /// Number of tokens added by each refill.
    refill_increment: u64,
    /// Simulated time between two consecutive refills.
    refill_interval: SimulationTime,

    // State: updated as tokens are removed and refills are applied.
    /// Number of tokens currently available for removal.
    balance: u64,
    /// Time of the most recent refill that has been accounted for.
    last_refill: EmulatedTime,
}
impl TokenBucket {
    /// Creates a full token bucket whose refill schedule is anchored at
    /// [`EmulatedTime::SIMULATION_START`].
    ///
    /// Returns `None` if `capacity` is zero, `refill_increment` is zero, or
    /// `refill_interval` is zero.
    pub fn new(
        capacity: u64,
        refill_increment: u64,
        refill_interval: SimulationTime,
    ) -> Option<TokenBucket> {
        TokenBucket::new_inner(
            capacity,
            refill_increment,
            refill_interval,
            EmulatedTime::SIMULATION_START,
        )
    }

    /// Like [`TokenBucket::new`], but anchors the refill schedule at
    /// `last_refill` instead of the simulation start (the tests call this
    /// directly with a mocked time).
    fn new_inner(
        capacity: u64,
        refill_increment: u64,
        refill_interval: SimulationTime,
        last_refill: EmulatedTime,
    ) -> Option<TokenBucket> {
        // A bucket that can never hold or regain tokens, or that would refill
        // infinitely often, is rejected.
        if capacity == 0 || refill_increment == 0 || refill_interval.is_zero() {
            return None;
        }

        log::trace!(
            "Initializing token bucket with capacity {}, will refill {} tokens every {:?}",
            capacity,
            refill_increment,
            refill_interval
        );

        Some(TokenBucket {
            capacity,
            balance: capacity,
            refill_increment,
            refill_interval,
            last_refill,
        })
    }

    /// Attempts to remove `decrement` tokens. It uses the worker's current
    /// simulated time to apply any pending refills first.
    ///
    /// On success, the tokens are removed and the remaining balance is
    /// returned. If the balance is insufficient, nothing is removed. The error
    /// then holds the duration after which enough refills will have happened
    /// for the removal to conform.
    ///
    /// If `decrement` exceeds the bucket's capacity, the removal can never
    /// succeed. The returned duration is still finite, but the balance is
    /// capped at `capacity`, so retrying after it elapses fails again.
    ///
    /// # Panics
    ///
    /// Panics if `Worker::current_time()` does not yield a time.
    pub fn conforming_remove(&mut self, decrement: u64) -> Result<u64, SimulationTime> {
        let now = Worker::current_time().expect("TokenBucket: no current time available from Worker");
        self.conforming_remove_inner(decrement, &now)
    }

    /// Misspelled alias of [`TokenBucket::conforming_remove`]. It is kept so
    /// that existing callers continue to compile; prefer `conforming_remove`.
    pub fn comforming_remove(&mut self, decrement: u64) -> Result<u64, SimulationTime> {
        self.conforming_remove(decrement)
    }

    /// Time-explicit core of [`TokenBucket::conforming_remove`]. See that
    /// method for the meaning of the result.
    fn conforming_remove_inner(
        &mut self,
        decrement: u64,
        now: &EmulatedTime,
    ) -> Result<u64, SimulationTime> {
        // Pending refills must be applied before checking the balance. The
        // returned span is needed to compute the wait if the check fails.
        let next_refill_span = self.lazy_refill(now);

        match self.balance.checked_sub(decrement) {
            Some(remaining) => {
                self.balance = remaining;
                Ok(remaining)
            }
            // Leave the balance untouched and report how long until it suffices.
            None => Err(self.compute_conforming_duration(decrement, next_refill_span)),
        }
    }

    /// Computes how long until the balance reaches `decrement` tokens. The
    /// next refill is `next_refill_span` away, and each subsequent refill
    /// follows after another `refill_interval`.
    ///
    /// This does not account for the capacity cap (see `conforming_remove`).
    fn compute_conforming_duration(
        &self,
        decrement: u64,
        next_refill_span: SimulationTime,
    ) -> SimulationTime {
        let required_token_increment = decrement.saturating_sub(self.balance);

        // Ceiling division: a partial refill's worth of tokens still needs a
        // whole refill. `refill_increment` is non-zero (checked in `new_inner`).
        let whole_refills = required_token_increment / self.refill_increment;
        let num_required_refills = if required_token_increment % self.refill_increment == 0 {
            whole_refills
        } else {
            whole_refills + 1
        };

        match num_required_refills {
            0 => SimulationTime::ZERO,
            1 => next_refill_span,
            // The first refill arrives after `next_refill_span`. Each further
            // refill takes a full `refill_interval`. `n >= 2` here.
            n => next_refill_span.saturating_add(self.refill_interval.saturating_mul(n - 1)),
        }
    }

    /// Applies every refill that should have happened between `last_refill` and
    /// `now`, and returns the time remaining until the next refill.
    ///
    /// `last_refill` only ever advances by whole multiples of `refill_interval`.
    /// This keeps the refill schedule aligned to its original anchor, even when
    /// refilled tokens are discarded because the bucket is already full.
    fn lazy_refill(&mut self, now: &EmulatedTime) -> SimulationTime {
        let mut span = now.duration_since(&self.last_refill);

        if span >= self.refill_interval {
            // `refill_interval` is non-zero (checked in `new_inner`), so this
            // division cannot panic.
            let num_refills: u64 = (span.as_nanos() / self.refill_interval.as_nanos())
                .try_into()
                .expect("TokenBucket: refill count does not fit in u64");

            let num_tokens = self.refill_increment.saturating_mul(num_refills);
            debug_assert!(num_tokens > 0);

            // Saturate to avoid u64 overflow, then never exceed capacity.
            self.balance = self.balance.saturating_add(num_tokens).min(self.capacity);

            // Advance by whole intervals only, so the schedule stays aligned.
            self.last_refill = self
                .last_refill
                .saturating_add(self.refill_interval.saturating_mul(num_refills));

            span = now.duration_since(&self.last_refill);
        }

        debug_assert!(span < self.refill_interval);
        self.refill_interval - span
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::network::tests::mock_time_millis;

    #[test]
    fn test_new_invalid_args() {
        let now = mock_time_millis(1000);
        assert!(TokenBucket::new_inner(0, 1, SimulationTime::from_nanos(1), now).is_none());
        assert!(TokenBucket::new_inner(1, 0, SimulationTime::from_nanos(1), now).is_none());
        assert!(TokenBucket::new_inner(1, 1, SimulationTime::ZERO, now).is_none());
    }

    #[test]
    fn test_new_valid_args() {
        let now = mock_time_millis(1000);
        assert!(TokenBucket::new_inner(1, 1, SimulationTime::from_nanos(1), now).is_some());
        assert!(TokenBucket::new_inner(1, 1, SimulationTime::from_millis(1), now).is_some());
        assert!(TokenBucket::new_inner(1, 1, SimulationTime::from_secs(1), now).is_some());
        let tb = TokenBucket::new_inner(54321, 12345, SimulationTime::from_secs(1), now).unwrap();
        assert_eq!(tb.capacity, 54321);
        assert_eq!(tb.refill_increment, 12345);
        assert_eq!(tb.refill_interval, SimulationTime::from_secs(1));
    }

    /// The public constructor validates its arguments like `new_inner` and
    /// starts with a full bucket.
    #[test]
    fn test_new_public_constructor() {
        assert!(TokenBucket::new(0, 1, SimulationTime::from_nanos(1)).is_none());
        assert!(TokenBucket::new(1, 0, SimulationTime::from_nanos(1)).is_none());
        assert!(TokenBucket::new(1, 1, SimulationTime::ZERO).is_none());
        let tb = TokenBucket::new(100, 10, SimulationTime::from_millis(10)).unwrap();
        assert_eq!(tb.capacity, 100);
        assert_eq!(tb.balance, 100);
    }

    #[test]
    fn test_refill_after_one_interval() {
        let interval = SimulationTime::from_millis(10);
        let capacity = 100;
        let increment = 10;
        let now = mock_time_millis(1000);
        let mut tb = TokenBucket::new_inner(capacity, increment, interval, now).unwrap();
        assert_eq!(tb.balance, capacity);
        assert!(tb.conforming_remove_inner(capacity, &now).is_ok());
        assert_eq!(tb.balance, 0);
        for i in 1..=(capacity / increment) {
            let later = now + interval.saturating_mul(i);
            let result = tb.conforming_remove_inner(0, &later);
            assert!(result.is_ok());
            assert_eq!(result.unwrap(), tb.balance);
            assert_eq!(tb.balance, increment.saturating_mul(i));
        }
    }

    #[test]
    fn test_refill_after_multiple_intervals() {
        let now = mock_time_millis(1000);
        let mut tb = TokenBucket::new_inner(100, 10, SimulationTime::from_millis(10), now).unwrap();
        assert!(tb.conforming_remove_inner(100, &now).is_ok());
        assert_eq!(tb.balance, 0);
        let later = now + SimulationTime::from_millis(50);
        let result = tb.conforming_remove_inner(0, &later);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), tb.balance);
        assert_eq!(tb.balance, 50);
    }

    #[test]
    fn test_capacity_limit() {
        let now = mock_time_millis(1000);
        let mut tb = TokenBucket::new_inner(100, 10, SimulationTime::from_millis(10), now).unwrap();
        assert!(tb.conforming_remove_inner(100, &now).is_ok());
        assert_eq!(tb.balance, 0);
        let later = now + SimulationTime::from_secs(60);
        let result = tb.conforming_remove_inner(0, &later);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), tb.balance);
        assert_eq!(tb.balance, 100);
    }

    #[test]
    fn test_remove_error() {
        let now = mock_time_millis(1000);
        let mut tb =
            TokenBucket::new_inner(100, 10, SimulationTime::from_millis(125), now).unwrap();
        let result = tb.conforming_remove_inner(100, &now);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 0);
        let result = tb.conforming_remove_inner(50, &now);
        assert!(result.is_err());
        let dur_until_conforming = SimulationTime::from_millis(125 * 5);
        assert_eq!(result.unwrap_err(), dur_until_conforming);
        let inc = 10;
        let now = mock_time_millis(1000 + inc);
        let result = tb.conforming_remove_inner(50, &now);
        assert!(result.is_err());
        let dur_until_conforming = SimulationTime::from_millis(125 * 5 - inc);
        assert_eq!(result.unwrap_err(), dur_until_conforming);
    }

    /// A failed removal must not consume tokens. The reported wait accounts
    /// for the existing balance and for time already elapsed in the current
    /// interval.
    #[test]
    fn test_remove_error_leaves_balance_unchanged() {
        let now = mock_time_millis(1000);
        let mut tb = TokenBucket::new_inner(100, 10, SimulationTime::from_millis(10), now).unwrap();

        assert_eq!(tb.conforming_remove_inner(70, &now), Ok(30));

        // 20 more tokens are needed: two refills, the first a full interval away.
        assert_eq!(
            tb.conforming_remove_inner(50, &now),
            Err(SimulationTime::from_millis(20))
        );
        assert_eq!(tb.balance, 30);

        // Halfway through the interval, only 5ms remain until the first refill.
        let later = now + SimulationTime::from_millis(5);
        assert_eq!(
            tb.conforming_remove_inner(50, &later),
            Err(SimulationTime::from_millis(15))
        );
        assert_eq!(tb.balance, 30);

        // Once the reported wait has elapsed, the removal conforms.
        let later = now + SimulationTime::from_millis(20);
        assert_eq!(tb.conforming_remove_inner(50, &later), Ok(0));
    }

    /// Refills discarded because the bucket is full still advance the refill
    /// schedule by whole intervals, so it stays aligned to its anchor.
    #[test]
    fn test_refill_schedule_stays_aligned_when_full() {
        let now = mock_time_millis(1000);
        let mut tb = TokenBucket::new_inner(100, 10, SimulationTime::from_millis(10), now).unwrap();

        // The refills at +10ms and +20ms are discarded (bucket already full).
        let later = now + SimulationTime::from_millis(25);
        assert_eq!(tb.conforming_remove_inner(100, &later), Ok(0));
        assert_eq!(tb.balance, 0);

        // The next refill is at +30ms, i.e. 5ms after `later`.
        assert_eq!(
            tb.conforming_remove_inner(10, &later),
            Err(SimulationTime::from_millis(5))
        );
    }
}