scheduler/sync/simple_latch.rs

1use std::sync::Arc;
2use std::sync::atomic::{AtomicU32, Ordering};
3
4use nix::errno::Errno;
5
/// A simple reusable latch. Multiple waiters can wait for the latch to open. After opening the
/// latch with [`open()`](Self::open), you must not open the latch again until all waiters have
/// waited with [`wait()`](LatchWaiter::wait) on the latch. In other words, you must not call
/// `open()` multiple times without making sure that all waiters have successfully returned from
/// `wait()` each time. This typically requires some other synchronization to make sure that the
/// waiters have waited. If the latch and its waiters aren't kept in sync, the waiters will usually
/// panic, but in some cases may behave incorrectly[^note].
///
/// [^note]: Since this latch uses a 32-bit wrapping integer to track the positions of the latch and
/// its waiters, calling `open()` `u32::MAX + 1` times without allowing the waiters to wait will
/// behave as if you did not call `open()` at all.
///
/// The latch uses release-acquire ordering, so any changes made before an `open()` should be
/// visible in other threads after a `wait()` returns.
#[derive(Debug)]
pub struct Latch {
    /// The generation of the latch. Incremented (wrapping) by one on each `open()`. Shared with
    /// every [`LatchWaiter`], which compares its own generation against this one and futex-waits
    /// on this word when the latch is not yet open.
    latch_gen: Arc<AtomicU32>,
}
25
/// A waiter that waits for the latch to open. A waiter for a latch can be created with
/// [`waiter()`](Latch::waiter). Cloning a waiter will create a new waiter with the same
/// state/generation as the existing waiter.
#[derive(Debug, Clone)]
pub struct LatchWaiter {
    /// The generation of this waiter. `wait()` returns once the latch generation is exactly one
    /// ahead of this value, then advances it by one (wrapping).
    waiter_gen: u32,
    /// The read-only generation of the latch (shared with the owning [`Latch`]).
    latch_gen: Arc<AtomicU32>,
    /// Should we sched_yield in a spinloop indefinitely rather than futex-wait?
    spin_yield: bool,
}
38
39impl Latch {
40    /// Create a new latch.
41    pub fn new() -> Self {
42        Self {
43            latch_gen: Arc::new(AtomicU32::new(0)),
44        }
45    }
46
47    /// Get a new waiter for this latch. The new waiter will have the same generation as the latch,
48    /// meaning that a single [`wait()`](LatchWaiter::wait) will block the waiter until the next
49    /// latch [`open()`](Self::open).
50    ///
51    /// If `spin_yield` is `true`, the waiter will `sched_yield` in a spinloop indefinitely. If
52    /// `spin_yield` is `false`, the waiter will futex-wait. Setting to `true` may improve
53    /// performance in some workloads.
54    pub fn waiter(&mut self, spin_yield: bool) -> LatchWaiter {
55        LatchWaiter {
56            // we're the only one who can mutate the atomic,
57            // so there's no race condition here
58            waiter_gen: self.latch_gen.load(Ordering::Relaxed),
59            latch_gen: Arc::clone(&self.latch_gen),
60            spin_yield,
61        }
62    }
63
64    /// Open the latch.
65    pub fn open(&mut self) {
66        // the addition is wrapping
67        self.latch_gen.fetch_add(1, Ordering::Release);
68
69        libc_futex(
70            &self.latch_gen,
71            libc::FUTEX_WAKE | libc::FUTEX_PRIVATE_FLAG,
72            // the man page says to use INT_MAX which is weird since this is a u32, but the kernel
73            // `do_futex` function implicitly casts this to an int when passing it to `futex_wake`
74            // (as of linux 6.6.8), so this seems like the right value to use
75            i32::MAX as u32,
76            None,
77            None,
78            0,
79        )
80        .expect("FUTEX_WAKE failed");
81    }
82}
83
84impl Default for Latch {
85    fn default() -> Self {
86        Self::new()
87    }
88}
89
90impl LatchWaiter {
91    /// Wait for the latch to open.
92    pub fn wait(&mut self) {
93        loop {
94            let latch_gen = self.latch_gen.load(Ordering::Acquire);
95
96            match latch_gen.wrapping_sub(self.waiter_gen) {
97                // the latch has been opened and we can advance to the next generation
98                1 => break,
99                // the latch has not been opened and we're at the same generation
100                0 => {}
101                // the latch has been opened multiple times and we haven't been kept in sync
102                _ => panic!("Latch has been opened multiple times without us waiting"),
103            }
104
105            if !self.spin_yield {
106                let rv = libc_futex(
107                    &self.latch_gen,
108                    libc::FUTEX_WAIT | libc::FUTEX_PRIVATE_FLAG,
109                    latch_gen,
110                    None,
111                    None,
112                    0,
113                );
114                assert!(
115                    matches!(rv, Ok(_) | Err(Errno::EAGAIN | Errno::EINTR)),
116                    "FUTEX_WAIT failed with {rv:?}"
117                );
118            } else {
119                // we don't know if a pause instruction is beneficial or not here, but it doesn't
120                // seem to hurt performance
121                // https://www.intel.com/content/www/us/en/docs/cpp-compiler/developer-guide-reference/2021-9/pause-intrinsic.html
122                std::hint::spin_loop();
123                std::thread::yield_now();
124            }
125        }
126
127        self.waiter_gen = self.waiter_gen.wrapping_add(1);
128    }
129}
130
131// Perform a futex operation using libc. Miri only understands futex syscalls made through the
132// [`libc::syscall`] function so we need to use it here. I don't see any reason to mark this as
133// "unsafe", but I didn't look through all of the possible futex operations.
134pub fn libc_futex(
135    uaddr: &AtomicU32,
136    op: core::ffi::c_int,
137    val: u32,
138    utime: Option<&libc::timespec>,
139    uaddr2: Option<&AtomicU32>,
140    val3: u32,
141) -> Result<core::ffi::c_int, Errno> {
142    let uaddr: *mut u32 = uaddr.as_ptr();
143    let utime: *const libc::timespec = utime
144        .map(std::ptr::from_ref)
145        .unwrap_or(core::ptr::null_mut());
146    let uaddr2: *mut u32 = uaddr2
147        .map(AtomicU32::as_ptr)
148        .unwrap_or(core::ptr::null_mut());
149
150    let rv = unsafe { libc::syscall(libc::SYS_futex, uaddr, op, val, utime, uaddr2, val3) };
151
152    if rv >= 0 {
153        // the linux x86-64 syscall implementation returns an int so I don't think this should ever
154        // fail
155        Ok(rv.try_into().expect("futex() returned invalid int"))
156    } else {
157        let errno = unsafe { *libc::__errno_location() };
158        debug_assert_eq!(rv, -1);
159        Err(Errno::from_raw(errno))
160    }
161}
162
#[cfg(test)]
mod tests {
    use std::thread::sleep;
    use std::time::{Duration, Instant};

    use atomic_refcell::AtomicRefCell;

    use super::*;

    /// A single waiter can wait repeatedly, as long as the latch is opened before each wait.
    #[test]
    fn test_simple() {
        let mut latch = Latch::new();
        let mut waiter = latch.waiter(false);

        latch.open();
        waiter.wait();
        latch.open();
        waiter.wait();
        latch.open();
        waiter.wait();
    }

    /// Opening the latch twice without letting the waiter wait in between violates the latch's
    /// contract, and the next `wait()` must panic.
    #[test]
    #[should_panic]
    fn test_multiple_open() {
        let mut latch = Latch::new();
        let mut waiter = latch.waiter(false);

        latch.open();
        waiter.wait();
        latch.open();
        latch.open();

        // this should panic
        waiter.wait();
    }

    /// A futex-waiting waiter actually blocks: the measured wait should be roughly the time
    /// until `open()` is called (checked within a timing threshold, so this test is inherently
    /// timing-sensitive and may flake on a heavily loaded machine).
    #[test]
    fn test_blocking() {
        let mut latch = Latch::new();
        let mut waiter = latch.waiter(false);

        let t = std::thread::spawn(move || {
            let start = Instant::now();
            waiter.wait();
            start.elapsed()
        });

        let sleep_duration = Duration::from_millis(200);
        sleep(sleep_duration);
        latch.open();

        let wait_duration = t.join().unwrap();

        let threshold = Duration::from_millis(40);
        assert!(wait_duration > sleep_duration - threshold);
        assert!(wait_duration < sleep_duration + threshold);
    }

    /// A cloned waiter shares the original waiter's generation, so both can wait on the same
    /// `open()` independently.
    #[test]
    fn test_clone() {
        let mut latch = Latch::new();
        let mut waiter = latch.waiter(false);

        latch.open();
        waiter.wait();
        latch.open();
        waiter.wait();

        // new waiter should have the same generation
        let mut waiter_2 = waiter.clone();

        latch.open();
        waiter.wait();
        waiter_2.wait();
    }

    /// Two threads alternate via two latches (covering both the spin-yield and futex-wait
    /// paths), each incrementing a shared counter once per round trip. The initial
    /// `latch_1.open()` kicks off the ping-pong; exact ordering matters here.
    #[test]
    fn test_ping_pong() {
        let mut latch_1 = Latch::new();
        let mut latch_2 = Latch::new();
        let mut waiter_1 = latch_1.waiter(true);
        let mut waiter_2 = latch_2.waiter(false);

        let counter = Arc::new(AtomicRefCell::new(0));
        let counter_clone = Arc::clone(&counter);

        // wait on our latch, bump the counter, then open the other thread's latch
        fn latch_loop(
            latch: &mut Latch,
            waiter: &mut LatchWaiter,
            counter: &Arc<AtomicRefCell<usize>>,
            iterations: usize,
        ) {
            for _ in 0..iterations {
                waiter.wait();
                *counter.borrow_mut() += 1;
                latch.open();
            }
        }

        let t = std::thread::spawn(move || {
            latch_loop(&mut latch_2, &mut waiter_1, &counter_clone, 100);
        });

        latch_1.open();
        latch_loop(&mut latch_1, &mut waiter_2, &counter, 100);

        t.join().unwrap();

        assert_eq!(*counter.borrow(), 200);
    }
}