// scheduler/sync/count_down_latch.rs

use std::sync::{Arc, Condvar, Mutex};
2
/// A latch counter.
///
/// If a counter is cloned, it will inherit the counter's state for the current generation. For
/// example if a counter is cloned after it has already counted down, then the new counter will also
/// be treated as if it had already counted down in the current generation. If a counter is cloned
/// before it has counted down, then the new counter will also need to count down in the current
/// generation.
#[derive(Debug)]
pub struct LatchCounter {
    // State shared with every other counter and waiter of the same latch.
    inner: Arc<LatchInner>,
    /// An ID for this counter's count-down round.
    generation: usize,
}
16
/// A latch waiter.
///
/// If a waiter is cloned, it will inherit the waiter's state for the current generation. For
/// example if a waiter is cloned after it has already waited, then the new waiter will also be
/// treated as if it had already waited in the current generation. If a waiter is cloned before it
/// has waited, then the new waiter will also need to wait in the current generation.
#[derive(Debug)]
pub struct LatchWaiter {
    // State shared with every other counter and waiter of the same latch.
    inner: Arc<LatchInner>,
    /// An ID for this waiter's count-down round.
    generation: usize,
}
29
/// State shared between all counters and waiters of a single latch.
#[derive(Debug)]
struct LatchInner {
    /// Protects the latch bookkeeping.
    lock: Mutex<LatchState>,
    /// Notified when the counter count for a generation reaches 0.
    cond: Condvar,
}
35
/// The latch's bookkeeping; only accessed while holding [`LatchInner::lock`].
#[derive(Debug)]
struct LatchState {
    /// The current latch "round".
    generation: usize,
    /// Number of counters remaining (that have not yet counted down this generation).
    counters: usize,
    /// Number of waiters remaining (that have not yet finished waiting this generation).
    waiters: usize,
    /// Total number of counters, including ones that already counted down this generation.
    total_counters: usize,
    /// Total number of waiters, including ones that already waited this generation.
    total_waiters: usize,
}
49
50/// Build a latch counter and waiter. The counter and waiter can be cloned to create new counters
51/// and waiters.
52pub fn build_count_down_latch() -> (LatchCounter, LatchWaiter) {
53    let inner = Arc::new(LatchInner {
54        lock: Mutex::new(LatchState {
55            generation: 0,
56            counters: 1,
57            waiters: 1,
58            total_counters: 1,
59            total_waiters: 1,
60        }),
61        cond: Condvar::new(),
62    });
63
64    let counter = LatchCounter {
65        inner: Arc::clone(&inner),
66        generation: 0,
67    };
68
69    let waiter = LatchWaiter {
70        inner,
71        generation: 0,
72    };
73
74    (counter, waiter)
75}
76
77impl LatchState {
78    pub fn advance_generation(&mut self) {
79        debug_assert_eq!(self.counters, 0);
80        debug_assert_eq!(self.waiters, 0);
81        self.counters = self.total_counters;
82        self.waiters = self.total_waiters;
83        self.generation = self.generation.wrapping_add(1);
84    }
85}
86
87impl LatchCounter {
88    /// Decrement the latch count and wake the waiters if the count reaches 0. This must not be
89    /// called more than once per generation (must not be called again until all of the waiters have
90    /// returned from their [`LatchWaiter::wait()`] calls), otherwise it will panic.
91    pub fn count_down(&mut self) {
92        let counters;
93        {
94            let mut lock = self.inner.lock.lock().unwrap();
95
96            if self.generation != lock.generation {
97                let latch_gen = lock.generation;
98                std::mem::drop(lock);
99                panic!(
100                    "Counter generation does not match latch generation ({} != {})",
101                    self.generation, latch_gen
102                );
103            }
104
105            lock.counters = lock.counters.checked_sub(1).unwrap();
106            counters = lock.counters;
107        }
108
109        // if this is the last counter, notify the waiters
110        if counters == 0 {
111            self.inner.cond.notify_all();
112        }
113
114        self.generation = self.generation.wrapping_add(1);
115    }
116}
117
118impl LatchWaiter {
119    /// Wait for the latch count to reach 0. If the latch count has already reached 0 for the
120    /// current genration, this will return immediately.
121    pub fn wait(&mut self) {
122        {
123            let lock = self.inner.lock.lock().unwrap();
124
125            let mut lock = self
126                .inner
127                .cond
128                // wait until we're in the active generation and all counters have counted down
129                .wait_while(lock, |x| self.generation != x.generation || x.counters > 0)
130                .unwrap();
131
132            lock.waiters = lock.waiters.checked_sub(1).unwrap();
133
134            // if this is the last waiter (and we already know that there are no more counters), start
135            // the next generation
136            if lock.waiters == 0 {
137                lock.advance_generation();
138            }
139        }
140
141        self.generation = self.generation.wrapping_add(1);
142    }
143}
144
145impl Clone for LatchCounter {
146    fn clone(&self) -> Self {
147        let mut lock = self.inner.lock.lock().unwrap();
148        lock.total_counters = lock.total_counters.checked_add(1).unwrap();
149
150        // if we haven't already counted down during the current generation
151        if self.generation == lock.generation {
152            lock.counters = lock.counters.checked_add(1).unwrap();
153        }
154
155        LatchCounter {
156            inner: Arc::clone(&self.inner),
157            generation: self.generation,
158        }
159    }
160}
161
162impl Clone for LatchWaiter {
163    fn clone(&self) -> Self {
164        let mut lock = self.inner.lock.lock().unwrap();
165        lock.total_waiters = lock.total_waiters.checked_add(1).unwrap();
166
167        // if we haven't already waited during the current generation
168        if self.generation == lock.generation {
169            lock.waiters = lock.waiters.checked_add(1).unwrap();
170        }
171
172        LatchWaiter {
173            inner: Arc::clone(&self.inner),
174            generation: self.generation,
175        }
176    }
177}
178
179impl std::ops::Drop for LatchCounter {
180    fn drop(&mut self) {
181        let mut lock = self.inner.lock.lock().unwrap();
182        lock.total_counters = lock.total_counters.checked_sub(1).unwrap();
183
184        // if we haven't already counted down during the current generation
185        if self.generation == lock.generation {
186            lock.counters = lock.counters.checked_sub(1).unwrap();
187        }
188
189        // if this is the last counter, notify the waiters
190        if lock.counters == 0 {
191            self.inner.cond.notify_all();
192        }
193    }
194}
195
196impl std::ops::Drop for LatchWaiter {
197    fn drop(&mut self) {
198        let mut lock = self.inner.lock.lock().unwrap();
199        lock.total_waiters = lock.total_waiters.checked_sub(1).unwrap();
200
201        // if we haven't already waited during the current generation
202        if self.generation == lock.generation {
203            lock.waiters = lock.waiters.checked_sub(1).unwrap();
204        }
205
206        // if this is the last waiter and there are no more counters, start the next generation
207        if lock.waiters == 0 && lock.counters == 0 {
208            lock.advance_generation();
209        }
210    }
211}
212
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use atomic_refcell::AtomicRefCell;
    use rand::{Rng, SeedableRng};

    use super::*;

    // A cloned counter/waiter participates in the generation it was cloned in.
    #[test]
    fn test_clone() {
        let (mut counter, mut waiter) = build_count_down_latch();
        let (mut counter_clone, mut waiter_clone) = (counter.clone(), waiter.clone());

        counter.count_down();
        counter_clone.count_down();
        waiter.wait();
        waiter_clone.wait();
    }

    #[test]
    fn test_clone_before_countdown() {
        let (mut counter, mut waiter) = build_count_down_latch();

        // the cloned counter will also need to count down for the current generation
        let mut counter_clone = counter.clone();
        counter.count_down();
        counter_clone.count_down();
        waiter.wait();

        counter.count_down();
        counter_clone.count_down();
        waiter.wait();

        let (mut counter, mut waiter) = build_count_down_latch();

        // the cloned waiter will also need to wait for the current generation
        let mut waiter_clone = waiter.clone();
        counter.count_down();
        waiter.wait();
        waiter_clone.wait();

        counter.count_down();
        waiter.wait();
        waiter_clone.wait();
    }

    #[test]
    fn test_clone_after_countdown() {
        let (mut counter, mut waiter) = build_count_down_latch();

        counter.count_down();
        // the cloned counter will also be considered "counted down" for the current generation
        let mut counter_clone = counter.clone();
        // if the cloned counter did count down here, it would panic
        waiter.wait();

        counter.count_down();
        counter_clone.count_down();
        waiter.wait();

        let (mut counter, mut waiter) = build_count_down_latch();
        let mut waiter_clone = waiter.clone();

        counter.count_down();
        waiter.wait();
        // the cloned waiter will also be considered "waited" for the current generation
        let mut waiter_clone_2 = waiter.clone();
        // if the cloned waiter did wait here, it would be waiting for the next generation
        waiter_clone.wait();

        counter.count_down();
        waiter.wait();
        waiter_clone.wait();
        waiter_clone_2.wait();
    }

    // Counting down twice in the same generation violates the latch contract.
    #[test]
    #[should_panic]
    fn test_double_count() {
        let (mut counter, mut _waiter) = build_count_down_latch();
        counter.count_down();
        counter.count_down();
    }

    #[test]
    fn test_single_thread() {
        let (mut counter, mut waiter) = build_count_down_latch();

        counter.count_down();
        waiter.wait();
        counter.count_down();
        waiter.wait();
        counter.count_down();
        waiter.wait();

        let mut waiter_clone = waiter.clone();

        counter.count_down();
        waiter.wait();
        waiter_clone.wait();

        counter.count_down();
        waiter.wait();
        waiter_clone.wait();
    }

    #[test]
    fn test_multi_thread() {
        let (mut exclusive_counter, mut exclusive_waiter) = build_count_down_latch();
        let (mut shared_counter, mut shared_waiter) = build_count_down_latch();
        let repeat = 30;

        let lock = Arc::new(AtomicRefCell::new(()));
        let lock_clone = Arc::clone(&lock);

        // The goal of this test is to make sure that the new threads alternate with the main thread
        // to access the atomic refcell. The new threads each hold on to a shared borrow of the
        // atomic refcell for ~5 ms, then the main thread gets an exclusive borrow for ~5 ms,
        // repeating. If these time slices ever overlap, then either a shared or exclusive borrow
        // will cause a panic and the test will fail. Randomness is added to the sleeps to vary the
        // order in which threads wait and count down, to try to cover more edge cases.

        let thread_fn = move |seed| {
            let mut rng = rand::rngs::StdRng::seed_from_u64(seed);

            for _ in 0..repeat {
                // wait for the main thread to be done with its exclusive borrow
                std::thread::sleep(Duration::from_millis(5));
                exclusive_waiter.wait();
                {
                    // a shared borrow for a duration in the range of 0-10 ms
                    let _x = lock_clone.borrow();
                    std::thread::sleep(Duration::from_millis(rng.random_range(0..10)));
                }
                shared_counter.count_down();
            }
        };

        // start 5 threads
        // (cloning the closure clones its captured waiter/counter, registering
        // each thread with both latches)
        let handles: Vec<_> = (0..5)
            .map(|seed| {
                let mut f = thread_fn.clone();
                std::thread::spawn(move || f(seed))
            })
            .collect();
        std::mem::drop(thread_fn);

        let mut rng = rand::rngs::StdRng::seed_from_u64(100);

        for _ in 0..repeat {
            {
                // an exclusive borrow for a duration in the range of 0-10 ms
                let _x = lock.borrow_mut();
                std::thread::sleep(Duration::from_millis(rng.random_range(0..10)));
            }
            exclusive_counter.count_down();
            // wait for the other threads to be done with their shared borrow
            std::thread::sleep(Duration::from_millis(5));
            shared_waiter.wait();
        }

        for h in handles {
            h.join().unwrap();
        }
    }
}
379}