scheduler/sync/
simple_latch.rs1use std::sync::Arc;
2use std::sync::atomic::{AtomicU32, Ordering};
3
4use nix::errno::Errno;
5
/// A reusable latch that its owner can open repeatedly and that [`LatchWaiter`]s block on.
///
/// The latch is a generation counter shared with every waiter created from it. Each call to
/// [`Latch::open`] bumps the counter and wakes blocked waiters. Each waiter expects to observe
/// exactly one bump per call to [`LatchWaiter::wait`].
#[derive(Debug)]
pub struct Latch {
    /// Generation counter bumped by every `open`; it is also the word passed to `futex(2)`.
    latch_gen: Arc<AtomicU32>,
}
25
/// A handle that blocks until its [`Latch`] is opened.
///
/// Cloning a waiter produces an independent waiter that has observed the same generation and
/// shares the same latch counter.
#[derive(Clone, Debug)]
pub struct LatchWaiter {
    /// Generation this waiter has already observed; `wait` returns once the latch reaches
    /// the following generation.
    waiter_gen: u32,
    /// Counter shared with the owning `Latch`.
    latch_gen: Arc<AtomicU32>,
    /// When `true`, `wait` spins and yields the thread instead of sleeping on a futex.
    spin_yield: bool,
}
38
39impl Latch {
40 pub fn new() -> Self {
42 Self {
43 latch_gen: Arc::new(AtomicU32::new(0)),
44 }
45 }
46
47 pub fn waiter(&mut self, spin_yield: bool) -> LatchWaiter {
55 LatchWaiter {
56 waiter_gen: self.latch_gen.load(Ordering::Relaxed),
59 latch_gen: Arc::clone(&self.latch_gen),
60 spin_yield,
61 }
62 }
63
64 pub fn open(&mut self) {
66 self.latch_gen.fetch_add(1, Ordering::Release);
68
69 libc_futex(
70 &self.latch_gen,
71 libc::FUTEX_WAKE | libc::FUTEX_PRIVATE_FLAG,
72 i32::MAX as u32,
76 None,
77 None,
78 0,
79 )
80 .expect("FUTEX_WAKE failed");
81 }
82}
83
84impl Default for Latch {
85 fn default() -> Self {
86 Self::new()
87 }
88}
89
90impl LatchWaiter {
91 pub fn wait(&mut self) {
93 loop {
94 let latch_gen = self.latch_gen.load(Ordering::Acquire);
95
96 match latch_gen.wrapping_sub(self.waiter_gen) {
97 1 => break,
99 0 => {}
101 _ => panic!("Latch has been opened multiple times without us waiting"),
103 }
104
105 if !self.spin_yield {
106 let rv = libc_futex(
107 &self.latch_gen,
108 libc::FUTEX_WAIT | libc::FUTEX_PRIVATE_FLAG,
109 latch_gen,
110 None,
111 None,
112 0,
113 );
114 assert!(
115 matches!(rv, Ok(_) | Err(Errno::EAGAIN | Errno::EINTR)),
116 "FUTEX_WAIT failed with {rv:?}"
117 );
118 } else {
119 std::hint::spin_loop();
123 std::thread::yield_now();
124 }
125 }
126
127 self.waiter_gen = self.waiter_gen.wrapping_add(1);
128 }
129}
130
131pub fn libc_futex(
135 uaddr: &AtomicU32,
136 op: core::ffi::c_int,
137 val: u32,
138 utime: Option<&libc::timespec>,
139 uaddr2: Option<&AtomicU32>,
140 val3: u32,
141) -> Result<core::ffi::c_int, Errno> {
142 let uaddr: *mut u32 = uaddr.as_ptr();
143 let utime: *const libc::timespec = utime
144 .map(std::ptr::from_ref)
145 .unwrap_or(core::ptr::null_mut());
146 let uaddr2: *mut u32 = uaddr2
147 .map(AtomicU32::as_ptr)
148 .unwrap_or(core::ptr::null_mut());
149
150 let rv = unsafe { libc::syscall(libc::SYS_futex, uaddr, op, val, utime, uaddr2, val3) };
151
152 if rv >= 0 {
153 Ok(rv.try_into().expect("futex() returned invalid int"))
156 } else {
157 let errno = unsafe { *libc::__errno_location() };
158 debug_assert_eq!(rv, -1);
159 Err(Errno::from_raw(errno))
160 }
161}
162
#[cfg(test)]
mod tests {
    use std::thread::sleep;
    use std::time::{Duration, Instant};

    use atomic_refcell::AtomicRefCell;

    use super::*;

    /// Alternating open/wait on a single thread releases the waiter each time.
    #[test]
    fn test_simple() {
        let mut latch = Latch::new();
        let mut waiter = latch.waiter(false);

        latch.open();
        waiter.wait();
        latch.open();
        waiter.wait();
        latch.open();
        waiter.wait();
    }

    /// Opening twice without the waiter observing the first opening must panic in `wait`.
    #[test]
    #[should_panic(expected = "Latch has been opened multiple times")]
    fn test_multiple_open() {
        let mut latch = Latch::new();
        let mut waiter = latch.waiter(false);

        latch.open();
        waiter.wait();
        latch.open();
        latch.open();

        waiter.wait();
    }

    /// A waiter on another thread stays blocked until the latch is opened, then wakes promptly.
    #[test]
    fn test_blocking() {
        let mut latch = Latch::new();
        let mut waiter = latch.waiter(false);

        // Take the reference time before spawning. If the thread took it itself, a slow thread
        // start would make the measured wait look shorter than the time the latch was closed.
        let start = Instant::now();
        let t = std::thread::spawn(move || {
            waiter.wait();
            start.elapsed()
        });

        let sleep_duration = Duration::from_millis(200);
        sleep(sleep_duration);
        latch.open();

        let wait_duration = t.join().unwrap();

        // The waiter cannot have returned before `open`, which happens after the full sleep.
        assert!(wait_duration >= sleep_duration);
        // ...and it should wake up soon after `open`.
        let threshold = Duration::from_millis(40);
        assert!(wait_duration < sleep_duration + threshold);
    }

    /// A cloned waiter starts at the same generation as the original, so both are released by
    /// the next `open`.
    #[test]
    fn test_clone() {
        let mut latch = Latch::new();
        let mut waiter = latch.waiter(false);

        latch.open();
        waiter.wait();
        latch.open();
        waiter.wait();

        let mut waiter_2 = waiter.clone();

        latch.open();
        waiter.wait();
        waiter_2.wait();
    }

    /// Two threads hand control back and forth through two latches. One side spins/yields and
    /// the other uses the futex. The latches serialize access to the shared counter.
    #[test]
    fn test_ping_pong() {
        let mut latch_1 = Latch::new();
        let mut latch_2 = Latch::new();
        let mut waiter_1 = latch_1.waiter(true);
        let mut waiter_2 = latch_2.waiter(false);

        let counter = Arc::new(AtomicRefCell::new(0));
        let counter_clone = Arc::clone(&counter);

        fn latch_loop(
            latch: &mut Latch,
            waiter: &mut LatchWaiter,
            counter: &Arc<AtomicRefCell<usize>>,
            iterations: usize,
        ) {
            for _ in 0..iterations {
                waiter.wait();
                *counter.borrow_mut() += 1;
                latch.open();
            }
        }

        let t = std::thread::spawn(move || {
            latch_loop(&mut latch_2, &mut waiter_1, &counter_clone, 100);
        });

        // Kick off the exchange by releasing the spawned thread first.
        latch_1.open();
        latch_loop(&mut latch_1, &mut waiter_2, &counter, 100);

        t.join().unwrap();

        assert_eq!(*counter.borrow(), 200);
    }
}