1use std::sync::{Arc, Condvar, Mutex};
/// A latch counter.
///
/// If a counter is cloned, it will inherit the counter's state for the current generation. For
/// example if a counter is cloned after it has already counted down, then the new counter will also
/// be treated as if it had already counted down in the current generation. If a counter is cloned
/// before it has counted down, then the new counter will also need to count down in the current
/// generation.
#[derive(Debug)]
pub struct LatchCounter {
    /// State shared with all counters and waiters attached to the same latch.
    inner: Arc<LatchInner>,
    /// An ID for this counter's count-down round.
    generation: usize,
}
/// A latch waiter.
///
/// If a waiter is cloned, it will inherit the waiter's state for the current generation. For
/// example if a waiter is cloned after it has already waited, then the new waiter will also be
/// treated as if it had already waited in the current generation. If a waiter is cloned before it
/// has waited, then the new waiter will also need to wait in the current generation.
#[derive(Debug)]
pub struct LatchWaiter {
    /// State shared with all counters and waiters attached to the same latch.
    inner: Arc<LatchInner>,
    /// An ID for this waiter's count-down round.
    generation: usize,
}
/// State shared between all counters and waiters of a single latch.
#[derive(Debug)]
struct LatchInner {
    /// Protects the latch bookkeeping.
    lock: Mutex<LatchState>,
    /// Notified when the last counter of a generation has counted down.
    cond: Condvar,
}
/// Latch bookkeeping; always accessed with `LatchInner::lock` held.
#[derive(Debug)]
struct LatchState {
    /// The current latch "round".
    generation: usize,
    /// Number of counters remaining (yet to count down) in this generation.
    counters: usize,
    /// Number of waiters remaining (yet to wait) in this generation.
    waiters: usize,
    /// Total number of counters; used to refill `counters` each generation.
    total_counters: usize,
    /// Total number of waiters; used to refill `waiters` each generation.
    total_waiters: usize,
}
4950/// Build a latch counter and waiter. The counter and waiter can be cloned to create new counters
51/// and waiters.
52pub fn build_count_down_latch() -> (LatchCounter, LatchWaiter) {
53let inner = Arc::new(LatchInner {
54 lock: Mutex::new(LatchState {
55 generation: 0,
56 counters: 1,
57 waiters: 1,
58 total_counters: 1,
59 total_waiters: 1,
60 }),
61 cond: Condvar::new(),
62 });
6364let counter = LatchCounter {
65 inner: Arc::clone(&inner),
66 generation: 0,
67 };
6869let waiter = LatchWaiter {
70 inner,
71 generation: 0,
72 };
7374 (counter, waiter)
75}
7677impl LatchState {
78pub fn advance_generation(&mut self) {
79debug_assert_eq!(self.counters, 0);
80debug_assert_eq!(self.waiters, 0);
81self.counters = self.total_counters;
82self.waiters = self.total_waiters;
83self.generation = self.generation.wrapping_add(1);
84 }
85}
8687impl LatchCounter {
88/// Decrement the latch count and wake the waiters if the count reaches 0. This must not be
89 /// called more than once per generation (must not be called again until all of the waiters have
90 /// returned from their [`LatchWaiter::wait()`] calls), otherwise it will panic.
91pub fn count_down(&mut self) {
92let counters;
93 {
94let mut lock = self.inner.lock.lock().unwrap();
9596if self.generation != lock.generation {
97let latch_gen = lock.generation;
98 std::mem::drop(lock);
99panic!(
100"Counter generation does not match latch generation ({} != {})",
101self.generation, latch_gen
102 );
103 }
104105 lock.counters = lock.counters.checked_sub(1).unwrap();
106 counters = lock.counters;
107 }
108109// if this is the last counter, notify the waiters
110if counters == 0 {
111self.inner.cond.notify_all();
112 }
113114self.generation = self.generation.wrapping_add(1);
115 }
116}
117118impl LatchWaiter {
119/// Wait for the latch count to reach 0. If the latch count has already reached 0 for the
120 /// current genration, this will return immediately.
121pub fn wait(&mut self) {
122 {
123let lock = self.inner.lock.lock().unwrap();
124125let mut lock = self
126.inner
127 .cond
128// wait until we're in the active generation and all counters have counted down
129.wait_while(lock, |x| self.generation != x.generation || x.counters > 0)
130 .unwrap();
131132 lock.waiters = lock.waiters.checked_sub(1).unwrap();
133134// if this is the last waiter (and we already know that there are no more counters), start
135 // the next generation
136if lock.waiters == 0 {
137 lock.advance_generation();
138 }
139 }
140141self.generation = self.generation.wrapping_add(1);
142 }
143}
144145impl Clone for LatchCounter {
146fn clone(&self) -> Self {
147let mut lock = self.inner.lock.lock().unwrap();
148 lock.total_counters = lock.total_counters.checked_add(1).unwrap();
149150// if we haven't already counted down during the current generation
151if self.generation == lock.generation {
152 lock.counters = lock.counters.checked_add(1).unwrap();
153 }
154155 LatchCounter {
156 inner: Arc::clone(&self.inner),
157 generation: self.generation,
158 }
159 }
160}
161162impl Clone for LatchWaiter {
163fn clone(&self) -> Self {
164let mut lock = self.inner.lock.lock().unwrap();
165 lock.total_waiters = lock.total_waiters.checked_add(1).unwrap();
166167// if we haven't already waited during the current generation
168if self.generation == lock.generation {
169 lock.waiters = lock.waiters.checked_add(1).unwrap();
170 }
171172 LatchWaiter {
173 inner: Arc::clone(&self.inner),
174 generation: self.generation,
175 }
176 }
177}
178179impl std::ops::Drop for LatchCounter {
180fn drop(&mut self) {
181let mut lock = self.inner.lock.lock().unwrap();
182 lock.total_counters = lock.total_counters.checked_sub(1).unwrap();
183184// if we haven't already counted down during the current generation
185if self.generation == lock.generation {
186 lock.counters = lock.counters.checked_sub(1).unwrap();
187 }
188189// if this is the last counter, notify the waiters
190if lock.counters == 0 {
191self.inner.cond.notify_all();
192 }
193 }
194}
195196impl std::ops::Drop for LatchWaiter {
197fn drop(&mut self) {
198let mut lock = self.inner.lock.lock().unwrap();
199 lock.total_waiters = lock.total_waiters.checked_sub(1).unwrap();
200201// if we haven't already waited during the current generation
202if self.generation == lock.generation {
203 lock.waiters = lock.waiters.checked_sub(1).unwrap();
204 }
205206// if this is the last waiter and there are no more counters, start the next generation
207if lock.waiters == 0 && lock.counters == 0 {
208 lock.advance_generation();
209 }
210 }
211}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use atomic_refcell::AtomicRefCell;
    use rand::{Rng, SeedableRng};

    use super::*;

    // Clones made before any count-down/wait participate normally in the first generation.
    #[test]
    fn test_clone() {
        let (mut counter, mut waiter) = build_count_down_latch();
        let (mut counter_clone, mut waiter_clone) = (counter.clone(), waiter.clone());

        counter.count_down();
        counter_clone.count_down();
        waiter.wait();
        waiter_clone.wait();
    }

    #[test]
    fn test_clone_before_countdown() {
        let (mut counter, mut waiter) = build_count_down_latch();

        // the cloned counter will also need to count down for the current generation
        let mut counter_clone = counter.clone();
        counter.count_down();
        counter_clone.count_down();
        waiter.wait();

        counter.count_down();
        counter_clone.count_down();
        waiter.wait();

        let (mut counter, mut waiter) = build_count_down_latch();

        // the cloned waiter will also need to wait for the current generation
        let mut waiter_clone = waiter.clone();
        counter.count_down();
        waiter.wait();
        waiter_clone.wait();

        counter.count_down();
        waiter.wait();
        waiter_clone.wait();
    }

    #[test]
    fn test_clone_after_countdown() {
        let (mut counter, mut waiter) = build_count_down_latch();

        counter.count_down();
        // the cloned counter will also be considered "counted down" for the current generation
        let mut counter_clone = counter.clone();
        // if the cloned counter did count down here, it would panic
        waiter.wait();

        counter.count_down();
        counter_clone.count_down();
        waiter.wait();

        let (mut counter, mut waiter) = build_count_down_latch();
        let mut waiter_clone = waiter.clone();

        counter.count_down();
        waiter.wait();
        // the cloned waiter will also be considered "waited" for the current generation
        let mut waiter_clone_2 = waiter.clone();
        // if the cloned waiter did wait here, it would be waiting for the next generation
        waiter_clone.wait();

        counter.count_down();
        waiter.wait();
        waiter_clone.wait();
        waiter_clone_2.wait();
    }

    // A second count_down() in the same generation must panic (generation mismatch).
    #[test]
    #[should_panic]
    fn test_double_count() {
        let (mut counter, mut _waiter) = build_count_down_latch();
        counter.count_down();
        counter.count_down();
    }

    // Repeated generations on a single thread, including a waiter cloned mid-stream.
    #[test]
    fn test_single_thread() {
        let (mut counter, mut waiter) = build_count_down_latch();

        counter.count_down();
        waiter.wait();
        counter.count_down();
        waiter.wait();
        counter.count_down();
        waiter.wait();

        let mut waiter_clone = waiter.clone();

        counter.count_down();
        waiter.wait();
        waiter_clone.wait();

        counter.count_down();
        waiter.wait();
        waiter_clone.wait();
    }

    #[test]
    fn test_multi_thread() {
        let (mut exclusive_counter, mut exclusive_waiter) = build_count_down_latch();
        let (mut shared_counter, mut shared_waiter) = build_count_down_latch();
        let repeat = 30;

        let lock = Arc::new(AtomicRefCell::new(()));
        let lock_clone = Arc::clone(&lock);

        // The goal of this test is to make sure that the new threads alternate with the main thread
        // to access the atomic refcell. The new threads each hold on to a shared borrow of the
        // atomic refcell for ~5 ms, then the main thread gets an exclusive borrow for ~5 ms,
        // repeating. If these time slices ever overlap, then either a shared or exclusive borrow
        // will cause a panic and the test will fail. Randomness is added to the sleeps to vary the
        // order in which threads wait and count down, to try to cover more edge cases.

        let thread_fn = move |seed| {
            let mut rng = rand::rngs::StdRng::seed_from_u64(seed);

            for _ in 0..repeat {
                // wait for the main thread to be done with its exclusive borrow
                std::thread::sleep(Duration::from_millis(5));
                exclusive_waiter.wait();
                {
                    // a shared borrow for a duration in the range of 0-10 ms
                    let _x = lock_clone.borrow();
                    std::thread::sleep(Duration::from_millis(rng.random_range(0..10)));
                }
                shared_counter.count_down();
            }
        };

        // start 5 threads
        let handles: Vec<_> = (0..5)
            .map(|seed| {
                let mut f = thread_fn.clone();
                std::thread::spawn(move || f(seed))
            })
            .collect();
        // drop the original closure so its counter/waiter clones don't outlive the threads
        std::mem::drop(thread_fn);

        let mut rng = rand::rngs::StdRng::seed_from_u64(100);

        for _ in 0..repeat {
            {
                // an exclusive borrow for a duration in the range of 0-10 ms
                let _x = lock.borrow_mut();
                std::thread::sleep(Duration::from_millis(rng.random_range(0..10)));
            }
            exclusive_counter.count_down();
            // wait for the other threads to be done with their shared borrow
            std::thread::sleep(Duration::from_millis(5));
            shared_waiter.wait();
        }

        for h in handles {
            h.join().unwrap();
        }
    }
}