scheduler/pools/unbounded.rs

// When comparing a loaded value that happens to be bool,
// assert_eq! reads better than assert!.
#![allow(clippy::bool_assert_comparison)]

use std::marker::PhantomData;
use std::ops::Deref;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};

use atomic_refcell::AtomicRefCell;

use crate::sync::count_down_latch::{self, build_count_down_latch};
use crate::sync::simple_latch;

// If making substantial changes to this scheduler, you should verify the compilation error message
// for each test at the end of this file to make sure that they correctly cause the expected
// compilation error. This work pool unsafely transmutes the task closure lifetime, and the
// commented tests are meant to make sure that the work pool does not allow unsound code to compile.
// Due to lifetime sub-typing/variance, Rust will sometimes allow closures with shorter or longer
// lifetimes than we specify in the API, so the tests check to make sure the closures are invariant
// over the lifetime and that the usage is sound.

/// A task that is run by the pool threads.
pub trait TaskFn: Fn(usize) + Send + Sync {}
impl<T> TaskFn for T where T: Fn(usize) + Send + Sync {}

/// A thread pool that runs a task on many threads. A task will run once on each thread.
pub struct UnboundedThreadPool {
    /// Handles for joining threads when they've exited.
    thread_handles: Vec<std::thread::JoinHandle<()>>,
    /// State shared between all threads.
    shared_state: Arc<SharedState>,
    /// A latch that is opened when the task is set. Indicates to the threads that they should start
    /// running the task.
    task_start_latch: simple_latch::Latch,
    /// The main thread uses this to wait for the threads to finish running the task.
    task_end_waiter: count_down_latch::LatchWaiter,
}

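/// State shared between the main thread and the pool's worker threads.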
pub struct SharedState {
    /// The task to run during the next round.
    task: AtomicRefCell<Option<Box<dyn TaskFn>>>,
    /// Has a thread panicked?
    has_thread_panicked: AtomicBool,
}

impl UnboundedThreadPool {
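    /// Create a new pool with `num_threads` threads, each named `thread_name`. If `yield_spin` is
    /// true, threads waiting for a new task will spin (with yielding) rather than block, which may
    /// improve performance under some conditions
    /// (see https://github.com/shadow/shadow/issues/2877).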
    pub fn new(num_threads: usize, thread_name: &str, yield_spin: bool) -> Self {
        let shared_state = Arc::new(SharedState {
            task: AtomicRefCell::new(None),
            has_thread_panicked: AtomicBool::new(false),
        });

        let (task_end_counter, task_end_waiter) = build_count_down_latch();
        let mut task_start_latch = simple_latch::Latch::new();

        let mut thread_handles = Vec::new();

        for i in 0..num_threads {
            let shared_state_clone = Arc::clone(&shared_state);

            // enabling spinning on the threads may improve performance under some conditions
            // (see https://github.com/shadow/shadow/issues/2877)
            let task_start_waiter = task_start_latch.waiter(yield_spin);

            let task_end_counter_clone = task_end_counter.clone();

            let handle = std::thread::Builder::new()
                .name(thread_name.to_string())
                .spawn(move || {
                    work_loop(
                        i,
                        shared_state_clone,
                        task_start_waiter,
                        task_end_counter_clone,
                    )
                })
                .unwrap();

            thread_handles.push(handle);
        }

        Self {
            thread_handles,
            shared_state,
            task_start_latch,
            task_end_waiter,
        }
    }

    /// Stop and join the threads.
    pub fn join(self) {
        // the drop handler will join the threads
    }

    fn join_internal(&mut self) {
        // a `None` indicates that the threads should end
        assert!(self.shared_state.task.borrow().is_none());

        // only check the thread join return value if no threads have yet panicked
        let check_for_errors = !self
            .shared_state
            .has_thread_panicked
            .load(Ordering::Relaxed);

        // wake the threads; since the task is `None`, they will exit their work loops
        self.task_start_latch.open();

        for handle in self.thread_handles.drain(..) {
            let result = handle.join();
            if check_for_errors {
                result.expect("A thread panicked while stopping");
            }
        }
    }

    /// Create a new scope for the pool. The scope will ensure that any task run on the pool within
    /// this scope has completed before leaving the scope.
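    ///
    /// # Examples
    ///
    /// A minimal usage sketch, mirroring the `test_run` test below (module path as used by the
    /// compile-fail doctests at the end of this file):
    ///
    /// ```
    /// # use shadow_rs::core::scheduler::pools::unbounded::UnboundedThreadPool;
    /// use std::sync::atomic::{AtomicU32, Ordering};
    ///
    /// let mut pool = UnboundedThreadPool::new(2, "worker", false);
    ///
    /// let counter = AtomicU32::new(0);
    /// pool.scope(|s| {
    ///     // the closure runs once on each of the pool's 2 threads
    ///     s.run(|_thread_index| {
    ///         counter.fetch_add(1, Ordering::SeqCst);
    ///     });
    /// });
    ///
    /// // the scope waited for the task to complete on every thread
    /// assert_eq!(counter.load(Ordering::SeqCst), 2);
    /// ```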
    //
    // SAFETY: This works because:
    //
    // 1. WorkerScope<'scope> is covariant over 'scope.
    // 2. TaskRunner<'a, 'scope> is invariant over WorkerScope<'scope>, so TaskRunner<'a, 'scope>
    //    is invariant over 'scope.
    // 3. FnOnce(TaskRunner<'a, 'scope>) is contravariant over TaskRunner<'a, 'scope>, so
    //    FnOnce(TaskRunner<'a, 'scope>) is invariant over 'scope.
    //
    // This means that the provided scope closure cannot take a TaskRunner<'a, 'scope2> where
    // 'scope2 is shorter than 'scope, and therefore 'scope must be as long as this function call.
    //
    // If TaskRunner<'a, 'scope> were covariant over 'scope, then FnOnce(TaskRunner<'a, 'scope>)
    // would have been contravariant over 'scope. This would have allowed the user to provide a
    // scope closure that could take a TaskRunner<'a, 'scope2> where 'scope2 is shorter than 'scope.
    // Then when TaskRunner<'a, 'scope2>::run(...) was eventually called, the run closure would
    // capture data with a lifetime of only 'scope2, which would be a shorter lifetime than the
    // scope closure's lifetime of 'scope. Then, any captured mutable references would be accessible
    // from both the run closure and the scope closure, leading to mutable aliasing.
    pub fn scope<'scope>(
        &'scope mut self,
        f: impl for<'a> FnOnce(TaskRunner<'a, 'scope>) + 'scope,
    ) {
        assert!(
            !self
                .shared_state
                .has_thread_panicked
                .load(Ordering::Relaxed),
            "Attempting to use a workpool that previously panicked"
        );

        // makes sure that the task is properly cleared even if 'f' panics
        let mut scope = WorkerScope::<'scope> {
            pool: self,
            _phantom: Default::default(),
        };

        let runner = TaskRunner { scope: &mut scope };

        f(runner);
    }
}

impl std::ops::Drop for UnboundedThreadPool {
    fn drop(&mut self) {
        self.join_internal();
    }
}

struct WorkerScope<'scope> {
    pool: &'scope mut UnboundedThreadPool,
    // when we are dropped, it's like dropping the task
    _phantom: PhantomData<Box<dyn TaskFn + 'scope>>,
}

impl std::ops::Drop for WorkerScope<'_> {
    fn drop(&mut self) {
        // if the task was set (i.e. if `TaskRunner::run` was called)
        if self.pool.shared_state.task.borrow().is_some() {
            // wait for the task to complete
            self.pool.task_end_waiter.wait();

            // clear the task
            *self.pool.shared_state.task.borrow_mut() = None;

            // generally following https://docs.rs/rayon/latest/rayon/fn.scope.html#panics
            if self
                .pool
                .shared_state
                .has_thread_panicked
                .load(Ordering::Relaxed)
            {
                // we could store the thread's panic message and propagate it, but I don't think
                // that's worth handling
                panic!("A work thread panicked");
            }
        }
    }
}

/// Allows a single task to run per pool scope.
pub struct TaskRunner<'a, 'scope> {
    // SAFETY: Self must be invariant over 'scope, which is why we use &mut here. See the
    // documentation for scope() above for details.
    scope: &'a mut WorkerScope<'scope>,
}

impl<'scope> TaskRunner<'_, 'scope> {
    /// Run a task on the pool's threads.
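    ///
    /// The task runs once on each pool thread, and each invocation is passed its thread's index
    /// (in the range `0..num_threads`).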
    pub fn run(self, f: impl TaskFn + 'scope) {
        let f = Box::new(f);

        // SAFETY: WorkerScope will drop this TaskFn before the end of 'scope
        let f = unsafe {
            std::mem::transmute::<Box<dyn TaskFn + 'scope>, Box<dyn TaskFn + 'static>>(f)
        };

        *self.scope.pool.shared_state.task.borrow_mut() = Some(f);

        // we've set the task, so start the threads
        self.scope.pool.task_start_latch.open();
    }
}

fn work_loop(
    thread_index: usize,
    shared_state: Arc<SharedState>,
    mut start_waiter: simple_latch::LatchWaiter,
    mut end_counter: count_down_latch::LatchCounter,
) {
    // we don't use `catch_unwind` here for two main reasons:
    //
    // 1. `catch_unwind` requires that the closure is `UnwindSafe`, which means that `TaskFn` also
    // needs to be `UnwindSafe`. This is a big restriction on the types of tasks that we could run,
    // since it requires that there's no interior mutability in the closure. rayon seems to get
    // around this by wrapping the closure in `AssertUnwindSafe`, under the assumption that the
    // panic will be propagated later with `resume_unwinding`, but this is a little more difficult
    // to reason about compared to simply avoiding `catch_unwind` altogether.
    // https://github.com/rayon-rs/rayon/blob/c571f8ffb4f74c8c09b4e1e6d9979b71b4414d07/rayon-core/src/unwind.rs#L9
    //
    // 2. There is a footgun with `catch_unwind` that could cause unexpected behaviour. If the
    // closure called `panic_any()` with a type that has a `Drop` implementation, and that `Drop`
    // implementation panics, it will cause a panic that is not caught by the `catch_unwind`,
    // causing the thread to panic again with no chance to clean up properly. The work pool would
    // then deadlock. Since we don't use `catch_unwind`, the thread will instead "panic when
    // panicking" and abort, which is a better outcome.
    // https://github.com/rust-lang/rust/issues/86027

    // this will poison the workpool when it's dropped
    struct PoisonWhenDropped<'a>(&'a SharedState);

    impl std::ops::Drop for PoisonWhenDropped<'_> {
        fn drop(&mut self) {
            // if we panicked, then inform other threads that we panicked and allow them to exit
            // gracefully
            self.0.has_thread_panicked.store(true, Ordering::Relaxed);
        }
    }

    let shared_state = shared_state.as_ref();
    let poison_when_dropped = PoisonWhenDropped(shared_state);

    loop {
        // wait for a new task
        start_waiter.wait();

        // scope used to make sure we drop the borrow of the task before counting down
        {
            // run the task
            match shared_state.task.borrow().deref() {
                Some(task) => (task)(thread_index),
                None => {
                    // received the sentinel value
                    break;
                }
            };
        }

        // SAFETY: we do not hold any references/borrows to the task at this time
        end_counter.count_down();
    }

    // didn't panic, so forget the poison handler and return normally
    std::mem::forget(poison_when_dropped);
}

#[cfg(any(test, doctest))]
mod tests {
    use std::sync::atomic::{AtomicBool, AtomicU32};

    use super::*;

    #[test]
    fn test_scope() {
        let mut pool = UnboundedThreadPool::new(4, "worker", false);

        let mut counter = 0u32;
        for _ in 0..3 {
            pool.scope(|_| {
                counter += 1;
            });
        }

        assert_eq!(counter, 3);
    }

    #[test]
    fn test_run() {
        let mut pool = UnboundedThreadPool::new(4, "worker", false);

        let counter = AtomicU32::new(0);
        for _ in 0..3 {
            pool.scope(|s| {
                s.run(|_| {
                    counter.fetch_add(1, Ordering::SeqCst);
                });
            });
        }

        assert_eq!(counter.load(Ordering::SeqCst), 12);
    }

    #[test]
    fn test_large_num_threads() {
        let mut pool = UnboundedThreadPool::new(100, "worker", false);

        let counter = AtomicU32::new(0);
        for _ in 0..3 {
            pool.scope(|s| {
                s.run(|_| {
                    counter.fetch_add(1, Ordering::SeqCst);
                });
            });
        }

        assert_eq!(counter.load(Ordering::SeqCst), 300);
    }

    #[test]
    fn test_scope_runner_order() {
        let mut pool = UnboundedThreadPool::new(1, "worker", false);

        let flag = AtomicBool::new(false);
        pool.scope(|s| {
            s.run(|_| {
                std::thread::sleep(std::time::Duration::from_millis(10));
                flag.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
                    .unwrap();
            });
            assert_eq!(flag.load(Ordering::SeqCst), false);
        });

        assert_eq!(flag.load(Ordering::SeqCst), true);
    }

    #[test]
    fn test_non_aliasing_borrows() {
        let mut pool = UnboundedThreadPool::new(4, "worker", false);

        let mut counter = 0;
        pool.scope(|s| {
            counter += 1;
            s.run(|_| {
                let _x = counter;
            });
        });

        assert_eq!(counter, 1);
    }

    // should not compile: "cannot assign to `counter` because it is borrowed"
    /// ```compile_fail
    /// # use shadow_rs::core::scheduler::pools::unbounded::*;
    /// let x = 5;
    /// let mut pool = UnboundedThreadPool::new(4, "worker", false);
    ///
    /// let mut counter = 0;
    /// pool.scope(|s| {
    ///     s.run(|_| {
    ///         let _x = counter;
    ///     });
    ///     counter += 1;
    /// });
    ///
    /// assert_eq!(counter, 1);
    /// ```
    fn _test_aliasing_borrows() {}

    #[test]
    #[should_panic]
    fn test_panic_all() {
        let mut pool = UnboundedThreadPool::new(4, "worker", false);

        pool.scope(|s| {
            s.run(|i| {
                // all threads panic
                panic!("{}", i);
            });
        });
    }

    #[test]
    #[should_panic]
    fn test_panic_single() {
        let mut pool = UnboundedThreadPool::new(4, "worker", false);

        pool.scope(|s| {
            s.run(|i| {
                // one thread panics
                if i == 2 {
                    panic!("{}", i);
                }
            });
        });
    }

    // should not compile: "`x` does not live long enough"
    /// ```compile_fail
    /// # use shadow_rs::core::scheduler::pools::unbounded::*;
    /// let mut pool = UnboundedThreadPool::new(4, "worker", false);
    ///
    /// let x = 5;
    /// pool.scope(|s| {
    ///     s.run(|_| {
    ///         std::panic::panic_any(&x);
    ///     });
    /// });
    /// ```
    fn _test_panic_any() {}

    // should not compile: "closure may outlive the current function, but it borrows `x`, which is
    // owned by the current function"
    /// ```compile_fail
    /// # use shadow_rs::core::scheduler::pools::unbounded::*;
    /// let mut pool = UnboundedThreadPool::new(4, "worker", false);
    ///
    /// pool.scope(|s| {
    ///     // 'x' will be dropped when the closure is dropped, but 's' lives longer than that
    ///     let x = 5;
    ///     s.run(|_| {
    ///         let _x = x;
    ///     });
    /// });
    /// ```
    fn _test_scope_lifetime() {}

    #[test]
    fn test_queues() {
        let num_threads = 4;
        let mut pool = UnboundedThreadPool::new(num_threads, "worker", false);

        // a non-copy usize wrapper
        struct Wrapper(usize);

        let queues: Vec<_> = (0..num_threads)
            .map(|_| crossbeam::queue::SegQueue::<Wrapper>::new())
            .collect();

        // queues[0] has Wrapper(0), queues[1] has Wrapper(1), etc
        for (i, queue) in queues.iter().enumerate() {
            queue.push(Wrapper(i));
        }

        let num_iters = 3;
        for _ in 0..num_iters {
            pool.scope(|s| {
                s.run(|i: usize| {
                    // take item from queue n and push it to queue n+1
                    let wrapper = queues[i].pop().unwrap();
                    queues[(i + 1) % num_threads].push(wrapper);
                });
            });
        }

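        // each wrapper moved forward by num_iters queues, so queue i should now hold the wrapper
        // that started in queue (i - num_iters) mod num_threads; the wrapping arithmetic below is
        // equivalent to that only because num_threads is a power of two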
        for (i, queue) in queues.iter().enumerate() {
            assert_eq!(
                queue.pop().unwrap().0,
                i.wrapping_sub(num_iters) % num_threads
            );
        }
    }
}