scheduler/sync/
count_down_latch.rs1use std::sync::{Arc, Condvar, Mutex};
2
3#[derive(Debug)]
11pub struct LatchCounter {
12 inner: Arc<LatchInner>,
13 generation: usize,
15}
16
17#[derive(Debug)]
24pub struct LatchWaiter {
25 inner: Arc<LatchInner>,
26 generation: usize,
28}
29
30#[derive(Debug)]
31struct LatchInner {
32 lock: Mutex<LatchState>,
33 cond: Condvar,
34}
35
/// Mutable bookkeeping for one latch; always accessed through the mutex in `LatchInner`.
#[derive(Debug)]
struct LatchState {
    /// Identifier of the current round; incremented with wrap-around.
    generation: usize,
    /// Counters that still have to count down in the current round.
    counters: usize,
    /// Waiters that still have to return from `wait()` in the current round.
    waiters: usize,
    /// Number of live counter handles; `counters` is reset to this when a new round starts.
    total_counters: usize,
    /// Number of live waiter handles; `waiters` is reset to this when a new round starts.
    total_waiters: usize,
}
49
50pub fn build_count_down_latch() -> (LatchCounter, LatchWaiter) {
53 let inner = Arc::new(LatchInner {
54 lock: Mutex::new(LatchState {
55 generation: 0,
56 counters: 1,
57 waiters: 1,
58 total_counters: 1,
59 total_waiters: 1,
60 }),
61 cond: Condvar::new(),
62 });
63
64 let counter = LatchCounter {
65 inner: Arc::clone(&inner),
66 generation: 0,
67 };
68
69 let waiter = LatchWaiter {
70 inner,
71 generation: 0,
72 };
73
74 (counter, waiter)
75}
76
77impl LatchState {
78 pub fn advance_generation(&mut self) {
79 debug_assert_eq!(self.counters, 0);
80 debug_assert_eq!(self.waiters, 0);
81 self.counters = self.total_counters;
82 self.waiters = self.total_waiters;
83 self.generation = self.generation.wrapping_add(1);
84 }
85}
86
87impl LatchCounter {
88 pub fn count_down(&mut self) {
92 let counters;
93 {
94 let mut lock = self.inner.lock.lock().unwrap();
95
96 if self.generation != lock.generation {
97 let latch_gen = lock.generation;
98 std::mem::drop(lock);
99 panic!(
100 "Counter generation does not match latch generation ({} != {})",
101 self.generation, latch_gen
102 );
103 }
104
105 lock.counters = lock.counters.checked_sub(1).unwrap();
106 counters = lock.counters;
107 }
108
109 if counters == 0 {
111 self.inner.cond.notify_all();
112 }
113
114 self.generation = self.generation.wrapping_add(1);
115 }
116}
117
118impl LatchWaiter {
119 pub fn wait(&mut self) {
122 {
123 let lock = self.inner.lock.lock().unwrap();
124
125 let mut lock = self
126 .inner
127 .cond
128 .wait_while(lock, |x| self.generation != x.generation || x.counters > 0)
130 .unwrap();
131
132 lock.waiters = lock.waiters.checked_sub(1).unwrap();
133
134 if lock.waiters == 0 {
137 lock.advance_generation();
138 }
139 }
140
141 self.generation = self.generation.wrapping_add(1);
142 }
143}
144
145impl Clone for LatchCounter {
146 fn clone(&self) -> Self {
147 let mut lock = self.inner.lock.lock().unwrap();
148 lock.total_counters = lock.total_counters.checked_add(1).unwrap();
149
150 if self.generation == lock.generation {
152 lock.counters = lock.counters.checked_add(1).unwrap();
153 }
154
155 LatchCounter {
156 inner: Arc::clone(&self.inner),
157 generation: self.generation,
158 }
159 }
160}
161
162impl Clone for LatchWaiter {
163 fn clone(&self) -> Self {
164 let mut lock = self.inner.lock.lock().unwrap();
165 lock.total_waiters = lock.total_waiters.checked_add(1).unwrap();
166
167 if self.generation == lock.generation {
169 lock.waiters = lock.waiters.checked_add(1).unwrap();
170 }
171
172 LatchWaiter {
173 inner: Arc::clone(&self.inner),
174 generation: self.generation,
175 }
176 }
177}
178
179impl std::ops::Drop for LatchCounter {
180 fn drop(&mut self) {
181 let mut lock = self.inner.lock.lock().unwrap();
182 lock.total_counters = lock.total_counters.checked_sub(1).unwrap();
183
184 if self.generation == lock.generation {
186 lock.counters = lock.counters.checked_sub(1).unwrap();
187 }
188
189 if lock.counters == 0 {
191 self.inner.cond.notify_all();
192 }
193 }
194}
195
196impl std::ops::Drop for LatchWaiter {
197 fn drop(&mut self) {
198 let mut lock = self.inner.lock.lock().unwrap();
199 lock.total_waiters = lock.total_waiters.checked_sub(1).unwrap();
200
201 if self.generation == lock.generation {
203 lock.waiters = lock.waiters.checked_sub(1).unwrap();
204 }
205
206 if lock.waiters == 0 && lock.counters == 0 {
208 lock.advance_generation();
209 }
210 }
211}
212
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use atomic_refcell::AtomicRefCell;
    use rand::{Rng, SeedableRng};

    use super::*;

    /// Two counters and two waiters (the second of each cloned from the first) complete a round.
    #[test]
    fn test_clone() {
        let (mut first_counter, mut first_waiter) = build_count_down_latch();
        let mut second_counter = first_counter.clone();
        let mut second_waiter = first_waiter.clone();

        first_counter.count_down();
        second_counter.count_down();
        first_waiter.wait();
        second_waiter.wait();
    }

    /// Handles cloned before acting in a round must also act in that round.
    #[test]
    fn test_clone_before_countdown() {
        // Counter cloned before counting down.
        let (mut counter, mut waiter) = build_count_down_latch();
        let mut extra_counter = counter.clone();
        for _ in 0..2 {
            counter.count_down();
            extra_counter.count_down();
            waiter.wait();
        }

        // Waiter cloned before waiting.
        let (mut counter, mut waiter) = build_count_down_latch();
        let mut extra_waiter = waiter.clone();
        for _ in 0..2 {
            counter.count_down();
            waiter.wait();
            extra_waiter.wait();
        }
    }

    /// Handles cloned after acting in a round are treated as having already acted in it.
    #[test]
    fn test_clone_after_countdown() {
        let (mut counter, mut waiter) = build_count_down_latch();
        counter.count_down();
        // The late clone is already done for this round.
        let mut late_counter = counter.clone();
        waiter.wait();

        counter.count_down();
        late_counter.count_down();
        waiter.wait();

        let (mut counter, mut waiter) = build_count_down_latch();
        let mut early_waiter = waiter.clone();
        counter.count_down();
        waiter.wait();
        // The late clone is already done for this round.
        let mut late_waiter = waiter.clone();
        early_waiter.wait();

        counter.count_down();
        waiter.wait();
        early_waiter.wait();
        late_waiter.wait();
    }

    /// Counting down twice in one round is a usage error.
    #[test]
    #[should_panic]
    fn test_double_count() {
        // Keep the waiter alive so the round cannot complete between the two calls.
        let (mut counter, _waiter) = build_count_down_latch();
        counter.count_down();
        counter.count_down();
    }

    /// Several rounds on one thread, before and after adding a second waiter.
    #[test]
    fn test_single_thread() {
        let (mut counter, mut waiter) = build_count_down_latch();
        for _ in 0..3 {
            counter.count_down();
            waiter.wait();
        }

        let mut waiter_copy = waiter.clone();
        for _ in 0..2 {
            counter.count_down();
            waiter.wait();
            waiter_copy.wait();
        }
    }

    /// Alternates an exclusive phase (main thread) with a shared phase (worker threads); the
    /// `AtomicRefCell` panics if the two phases ever overlap.
    #[test]
    fn test_multi_thread() {
        let (mut exclusive_counter, exclusive_waiter) = build_count_down_latch();
        let (shared_counter, mut shared_waiter) = build_count_down_latch();
        let repeat = 30;

        let phase_lock = Arc::new(AtomicRefCell::new(()));

        let handles: Vec<_> = (0..5u64)
            .map(|seed| {
                let mut waiter = exclusive_waiter.clone();
                let mut counter = shared_counter.clone();
                let worker_lock = Arc::clone(&phase_lock);
                std::thread::spawn(move || {
                    let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
                    for _ in 0..repeat {
                        std::thread::sleep(Duration::from_millis(5));
                        waiter.wait();
                        {
                            let _shared = worker_lock.borrow();
                            std::thread::sleep(Duration::from_millis(rng.random_range(0..10)));
                        }
                        counter.count_down();
                    }
                })
            })
            .collect();

        // Release the template handles so only the workers' clones take part in each round.
        drop(exclusive_waiter);
        drop(shared_counter);

        let mut rng = rand::rngs::StdRng::seed_from_u64(100);
        for _ in 0..repeat {
            {
                let _exclusive = phase_lock.borrow_mut();
                std::thread::sleep(Duration::from_millis(rng.random_range(0..10)));
            }
            exclusive_counter.count_down();
            std::thread::sleep(Duration::from_millis(5));
            shared_waiter.wait();
        }

        for handle in handles {
            handle.join().unwrap();
        }
    }
}