scheduler/pools/
bounded.rs

1// When comparing a loaded value that happens to be bool,
2// assert_eq! reads better than assert!.
3#![allow(clippy::bool_assert_comparison)]
4
5use std::marker::PhantomData;
6use std::ops::Deref;
7use std::sync::Arc;
8use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
9
10use atomic_refcell::AtomicRefCell;
11
12use crate::logical_processor::LogicalProcessors;
13use crate::sync::count_down_latch::{LatchCounter, LatchWaiter, build_count_down_latch};
14use crate::sync::thread_parking::{ThreadUnparker, ThreadUnparkerUnassigned};
15
16// If making substantial changes to this scheduler, you should verify the compilation error message
17// for each test at the end of this file to make sure that they correctly cause the expected
18// compilation error. This work pool unsafely transmutes the task closure lifetime, and the
19// commented tests are meant to make sure that the work pool does not allow unsound code to compile.
20// Due to lifetime sub-typing/variance, rust will sometimes allow closures with shorter or longer
21// lifetimes than we specify in the API, so the tests check to make sure the closures are invariant
22// over the lifetime and that the usage is sound.
23
/// Context information provided to each task closure.
///
/// A fresh value is built by the pool thread for every task invocation, describing which thread is
/// running the task and which logical processor that thread is currently assigned to.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TaskData {
    /// Index of the pool thread that is running the task.
    pub thread_idx: usize,
    /// Index of the logical processor that the thread is assigned to for this run.
    pub processor_idx: usize,
    /// The OS CPU ID associated with the logical processor, or `None` if it has none.
    pub cpu_id: Option<u32>,
}
30
31/// A task that is run by the pool threads.
32trait TaskFn: Fn(&TaskData) + Send + Sync {}
33impl<T> TaskFn for T where T: Fn(&TaskData) + Send + Sync {}
34
/// A thread pool that runs a task on many threads. A task will run once on each thread. Each
/// logical processor will run threads sequentially, meaning that the thread pool's parallelism
/// depends on the number of processors, not the number of threads. Threads are assigned to logical
/// processors, which can be bound to operating system processors.
///
/// Tasks are submitted through [`ParallelismBoundedThreadPool::scope`]. The threads are stopped and
/// joined when the pool is dropped (or when [`ParallelismBoundedThreadPool::join`] is called,
/// which simply consumes the pool).
pub struct ParallelismBoundedThreadPool {
    /// Handles for joining threads when they've exited. Drained when the threads are joined.
    thread_handles: Vec<std::thread::JoinHandle<()>>,
    /// State shared between all threads.
    shared_state: Arc<SharedState>,
    /// The main thread uses this to wait for the threads to finish running the task.
    task_end_waiter: LatchWaiter,
}
47
/// State shared between the pool handle and all of the pool's threads.
pub struct SharedState {
    /// The task to run during the next round. A `None` seen by a woken thread is the sentinel
    /// telling it to exit its work loop.
    task: AtomicRefCell<Option<Box<dyn TaskFn>>>,
    /// Has a thread panicked? Once set, the pool refuses to open new scopes.
    has_thread_panicked: AtomicBool,
    /// The logical processors. Borrowed shared by the threads while a task runs, and borrowed
    /// exclusively by the pool when resetting after a task has completed.
    logical_processors: AtomicRefCell<LogicalProcessors>,
    /// The threads which run on logical processors, indexed by thread index.
    threads: Vec<ThreadScheduling>,
}
58
/// Scheduling state for a thread.
pub struct ThreadScheduling {
    /// Used to unpark the thread when it has a new task.
    unparker: ThreadUnparker,
    /// The OS thread ID (tid) for this thread, used when setting the thread's CPU affinity. This
    /// will have an invalid value when running under miri.
    #[cfg_attr(miri, allow(dead_code))]
    tid: nix::unistd::Pid,
    /// The logical processor index that this thread is assigned to. This is updated when the
    /// thread is started by a logical processor other than the one it was assigned to.
    logical_processor_idx: AtomicUsize,
}
69
70impl ParallelismBoundedThreadPool {
71    /// A new work pool with logical processors that are pinned to the provided OS processors.
72    /// Each logical processor is assigned many threads.
73    pub fn new(cpu_ids: &[Option<u32>], num_threads: usize, thread_name: &str) -> Self {
74        // we don't need more logical processors than threads
75        let cpu_ids = &cpu_ids[..std::cmp::min(cpu_ids.len(), num_threads)];
76
77        let logical_processors = LogicalProcessors::new(cpu_ids, num_threads);
78
79        let (task_end_counter, task_end_waiter) = build_count_down_latch();
80
81        let mut thread_handles = Vec::new();
82        let mut shared_state_senders = Vec::new();
83        let mut tids = Vec::new();
84
85        // start the threads
86        for i in 0..num_threads {
87            // the thread will send us the tid, then we'll later send the shared state to the thread
88            let (tid_send, tid_recv) = crossbeam::channel::bounded(1);
89            let (shared_state_send, shared_state_recv) = crossbeam::channel::bounded(1);
90
91            let task_end_counter_clone = task_end_counter.clone();
92
93            let handle = std::thread::Builder::new()
94                .name(thread_name.to_string())
95                .spawn(move || work_loop(i, tid_send, shared_state_recv, task_end_counter_clone))
96                .unwrap();
97
98            thread_handles.push(handle);
99            shared_state_senders.push(shared_state_send);
100            tids.push(tid_recv.recv().unwrap());
101        }
102
103        // build the scheduling data for the threads
104        let thread_data: Vec<ThreadScheduling> = logical_processors
105            .iter()
106            .cycle()
107            .zip(&tids)
108            .zip(&thread_handles)
109            .map(|((processor_idx, tid), handle)| ThreadScheduling {
110                unparker: ThreadUnparkerUnassigned::new().assign(handle.thread().clone()),
111                tid: *tid,
112                logical_processor_idx: AtomicUsize::new(processor_idx),
113            })
114            .collect();
115
116        // add each thread to its logical processor
117        for (thread_idx, thread) in thread_data.iter().enumerate() {
118            let logical_processor_idx = thread.logical_processor_idx.load(Ordering::Relaxed);
119            logical_processors.add_worker(logical_processor_idx, thread_idx);
120        }
121
122        // state shared between all threads
123        let shared_state = Arc::new(SharedState {
124            task: AtomicRefCell::new(None),
125            has_thread_panicked: AtomicBool::new(false),
126            logical_processors: AtomicRefCell::new(logical_processors),
127            threads: thread_data,
128        });
129
130        // send the shared state to each thread
131        for s in shared_state_senders.into_iter() {
132            s.send(Arc::clone(&shared_state)).unwrap();
133        }
134
135        Self {
136            thread_handles,
137            shared_state,
138            task_end_waiter,
139        }
140    }
141
142    /// The total number of logical processors.
143    pub fn num_processors(&self) -> usize {
144        self.shared_state.logical_processors.borrow().iter().len()
145    }
146
147    /// The total number of threads.
148    pub fn num_threads(&self) -> usize {
149        self.thread_handles.len()
150    }
151
152    /// Stop and join the threads.
153    pub fn join(self) {
154        // the drop handler will join the threads
155    }
156
157    fn join_internal(&mut self) {
158        // a `None` indicates that the threads should end
159        assert!(self.shared_state.task.borrow().is_none());
160
161        // only check the thread join return value if no threads have yet panicked
162        let check_for_errors = !self
163            .shared_state
164            .has_thread_panicked
165            .load(Ordering::Relaxed);
166
167        // send the sentinel task to all threads
168        for thread in &self.shared_state.threads {
169            thread.unparker.unpark();
170        }
171
172        for handle in self.thread_handles.drain(..) {
173            let result = handle.join();
174            if check_for_errors {
175                result.expect("A thread panicked while stopping");
176            }
177        }
178    }
179
180    /// Create a new scope for the pool. The scope will ensure that any task run on the pool within
181    /// this scope has completed before leaving the scope.
182    //
183    // SAFETY: This works because:
184    //
185    // 1. WorkerScope<'scope> is covariant over 'scope.
186    // 2. TaskRunner<'a, 'scope> is invariant over WorkerScope<'scope>, so TaskRunner<'a, 'scope>
187    //    is invariant over 'scope.
188    // 3. FnOnce(TaskRunner<'a, 'scope>) is contravariant over TaskRunner<'a, 'scope>, so
189    //    FnOnce(TaskRunner<'a, 'scope>) is invariant over 'scope.
190    //
191    // This means that the provided scope closure cannot take a TaskRunner<'a, 'scope2> where
192    // 'scope2 is shorter than 'scope, and therefore 'scope must be as long as this function call.
193    //
194    // If TaskRunner<'a, 'scope> was covariant over 'scope, then FnOnce(TaskRunner<'a, 'scope>)
195    // would have been contravariant over 'scope. This would have allowed the user to provide a
196    // scope closure that could take a TaskRunner<'a, 'scope2> where 'scope2 is shorter than 'scope.
197    // Then when TaskRunner<'a, 'scope2>::run(...) would eventually be called, the run closure would
198    // capture data with a lifetime of only 'scope2, which would be a shorter lifetime than the
199    // scope closure's lifetime of 'scope. Then, any captured mutable references would be accessible
200    // from both the run closure and the scope closure, leading to mutable aliasing.
201    pub fn scope<'scope>(
202        &'scope mut self,
203        f: impl for<'a> FnOnce(TaskRunner<'a, 'scope>) + 'scope,
204    ) {
205        assert!(
206            !self
207                .shared_state
208                .has_thread_panicked
209                .load(Ordering::Relaxed),
210            "Attempting to use a workpool that previously panicked"
211        );
212
213        // makes sure that the task is properly cleared even if 'f' panics
214        let mut scope = WorkerScope::<'scope> {
215            pool: self,
216            _phantom: Default::default(),
217        };
218
219        let runner = TaskRunner { scope: &mut scope };
220
221        f(runner);
222    }
223}
224
225impl std::ops::Drop for ParallelismBoundedThreadPool {
226    fn drop(&mut self) {
227        self.join_internal();
228    }
229}
230
/// A guard for one call to `ParallelismBoundedThreadPool::scope()`. It holds the exclusive borrow
/// of the pool for the duration of the scope, and its drop handler waits for and clears any task
/// that was started within the scope.
struct WorkerScope<'scope> {
    /// The pool that tasks in this scope run on.
    pool: &'scope mut ParallelismBoundedThreadPool,
    // when we are dropped, it's like dropping the task
    _phantom: PhantomData<Box<dyn TaskFn + 'scope>>,
}
236
237impl std::ops::Drop for WorkerScope<'_> {
238    fn drop(&mut self) {
239        // if the task was set (if `TaskRunner::run` was called)
240        if self.pool.shared_state.task.borrow().is_some() {
241            // wait for the task to complete
242            self.pool.task_end_waiter.wait();
243
244            // clear the task
245            *self.pool.shared_state.task.borrow_mut() = None;
246
247            // we should have run every thread, so swap the logical processors' internal queues
248            self.pool
249                .shared_state
250                .logical_processors
251                .borrow_mut()
252                .reset();
253
254            // generally following https://docs.rs/rayon/latest/rayon/fn.scope.html#panics
255            if self
256                .pool
257                .shared_state
258                .has_thread_panicked
259                .load(Ordering::Relaxed)
260            {
261                // we could store the thread's panic message and propagate it, but I don't think
262                // that's worth handling
263                panic!("A work thread panicked");
264            }
265        }
266    }
267}
268
/// Allows a single task to run per pool scope. The runner is consumed by [`TaskRunner::run`], so
/// it cannot be used to start a second task.
pub struct TaskRunner<'a, 'scope> {
    // SAFETY: Self must be invariant over 'scope, which is why we use &mut here. See the
    // documentation for scope() above for details.
    scope: &'a mut WorkerScope<'scope>,
}
275
276impl<'scope> TaskRunner<'_, 'scope> {
277    /// Run a task on the pool's threads.
278    // unfortunately we need to use `Fn(&TaskData) + Send + Sync` and not `TaskFn` here, otherwise
279    // rust's type inference doesn't work nicely in the calling code
280    pub fn run(self, f: impl Fn(&TaskData) + Send + Sync + 'scope) {
281        let f = Box::new(f);
282
283        // SAFETY: WorkerScope will drop this TaskFn before the end of 'scope
284        let f = unsafe {
285            std::mem::transmute::<Box<dyn TaskFn + 'scope>, Box<dyn TaskFn + 'static>>(f)
286        };
287
288        *self.scope.pool.shared_state.task.borrow_mut() = Some(f);
289
290        let logical_processors = self.scope.pool.shared_state.logical_processors.borrow();
291
292        // start the first thread for each logical processor
293        for processor_idx in logical_processors.iter() {
294            start_next_thread(
295                processor_idx,
296                &self.scope.pool.shared_state,
297                &logical_processors,
298            );
299        }
300    }
301}
302
/// The main function of each pool thread.
///
/// The thread first reports its tid through `tid_send` and then blocks until it receives the
/// shared state through `shared_state_recv`. After that it loops: park until woken, run the
/// current task once (or leave the loop if the task is `None`), start the next thread on its
/// logical processor, and count down `end_counter`.
///
/// `thread_idx` is this thread's index into `SharedState::threads`.
///
/// If the task panics, the unwinding thread marks the pool as having panicked (see
/// `PoisonWhenDropped` below) and still starts the next thread on its logical processor.
fn work_loop(
    thread_idx: usize,
    tid_send: crossbeam::channel::Sender<nix::unistd::Pid>,
    shared_state_recv: crossbeam::channel::Receiver<Arc<SharedState>>,
    mut end_counter: LatchCounter,
) {
    // we don't use `catch_unwind` here for two main reasons:
    //
    // 1. `catch_unwind` requires that the closure is `UnwindSafe`, which means that `TaskFn` also
    // needs to be `UnwindSafe`. This is a big restriction on the types of tasks that we could run,
    // since it requires that there's no interior mutability in the closure. rayon seems to get
    // around this by wrapping the closure in `AssertUnwindSafe`, under the assumption that the
    // panic will be propagated later with `resume_unwinding`, but this is a little more difficult
    // to reason about compared to simply avoiding `catch_unwind` altogether.
    // https://github.com/rayon-rs/rayon/blob/c571f8ffb4f74c8c09b4e1e6d9979b71b4414d07/rayon-core/src/unwind.rs#L9
    //
    // 2. There is a footgun with `catch_unwind` that could cause unexpected behaviour. If the
    // closure called `panic_any()` with a type that has a Drop implementation, and that Drop
    // implementation panics, it will cause a panic that is not caught by the `catch_unwind`,
    // causing the thread to panic again with no chance to clean up properly. The work pool would
    // then deadlock. Since we don't use `catch_unwind`, the thread will instead "panic when
    // panicking" and abort, which is a more ideal outcome.
    // https://github.com/rust-lang/rust/issues/86027

    // this will poison the workpool when it's dropped
    struct PoisonWhenDropped<'a>(&'a SharedState);

    impl std::ops::Drop for PoisonWhenDropped<'_> {
        fn drop(&mut self) {
            // if we panicked, then inform other threads that we panicked and allow them to exit
            // gracefully
            self.0.has_thread_panicked.store(true, Ordering::Relaxed);
        }
    }

    // this will start the next thread when it's dropped
    struct StartNextThreadOnDrop<'a> {
        shared_state: &'a SharedState,
        logical_processors: &'a LogicalProcessors,
        current_processor_idx: usize,
    }

    impl std::ops::Drop for StartNextThreadOnDrop<'_> {
        fn drop(&mut self) {
            start_next_thread(
                self.current_processor_idx,
                self.shared_state,
                self.logical_processors,
            );
        }
    }

    let tid = if cfg!(not(miri)) {
        nix::unistd::gettid()
    } else {
        // the sched_setaffinity() should be disabled under miri, so this should be fine
        nix::unistd::Pid::from_raw(-1)
    };

    // send this thread's tid to the main thread
    tid_send.send(tid).unwrap();

    // get the shared state
    let shared_state = shared_state_recv.recv().unwrap();
    let shared_state = shared_state.as_ref();

    // from here on, leaving this function by unwinding marks the pool as panicked; the normal
    // exit path at the bottom forgets this guard instead of dropping it
    let poison_when_dropped = PoisonWhenDropped(shared_state);

    let thread_data = &shared_state.threads[thread_idx];
    let thread_parker = thread_data.unparker.parker();

    loop {
        // wait for a new task
        thread_parker.park();

        // scope used to make sure we drop everything (including the task) before counting down
        {
            let logical_processors = &shared_state.logical_processors.borrow();

            // the logical processor for this thread may have been changed by the previous thread if
            // the thread was stolen from another logical processor
            let current_processor_idx = thread_data.logical_processor_idx.load(Ordering::Relaxed);

            // this will start the next thread even if the below task panics or we break from the
            // loop
            //
            // we must start the next thread before we count down, otherwise we'll have runtime
            // panics due to simultaneous exclusive and shared borrows of `logical_processors`
            let _start_next_thread_when_dropped = StartNextThreadOnDrop {
                shared_state,
                logical_processors,
                current_processor_idx,
            };

            // context information for the task
            let task_data = TaskData {
                thread_idx,
                processor_idx: current_processor_idx,
                cpu_id: logical_processors.cpu_id(current_processor_idx),
            };

            // run the task
            match shared_state.task.borrow().deref() {
                Some(task) => (task)(&task_data),
                None => {
                    // received the sentinel value
                    break;
                }
            };
        }

        // SAFETY: we do not hold any references/borrows to the task at this time
        end_counter.count_down();
    }

    // didn't panic, so forget the poison handler and return normally
    std::mem::forget(poison_when_dropped);
}
421
422/// Choose the next thread to run on the logical processor, and then start it.
423fn start_next_thread(
424    processor_idx: usize,
425    shared_state: &SharedState,
426    logical_processors: &LogicalProcessors,
427) {
428    // if there is a thread to run on this logical processor, then start it
429    if let Some((next_thread_idx, from_processor_idx)) =
430        logical_processors.next_worker(processor_idx)
431    {
432        let next_thread = &shared_state.threads[next_thread_idx];
433
434        debug_assert_eq!(
435            from_processor_idx,
436            next_thread.logical_processor_idx.load(Ordering::Relaxed)
437        );
438
439        // if the next thread is assigned to a different processor
440        if processor_idx != from_processor_idx {
441            assign_to_processor(next_thread, processor_idx, logical_processors);
442        }
443
444        // start the thread
445        next_thread.unparker.unpark();
446    }
447}
448
449/// Assigns the thread to the logical processor.
450fn assign_to_processor(
451    thread: &ThreadScheduling,
452    processor_idx: usize,
453    logical_processors: &LogicalProcessors,
454) {
455    // set thread's affinity if the logical processor has a cpu ID
456    if let Some(cpu_id) = logical_processors.cpu_id(processor_idx) {
457        let mut cpus = nix::sched::CpuSet::new();
458        cpus.set(cpu_id as usize).unwrap();
459
460        // only set the affinity if not running in miri
461        #[cfg(not(miri))]
462        nix::sched::sched_setaffinity(thread.tid, &cpus).unwrap();
463    }
464
465    // set thread's processor
466    thread
467        .logical_processor_idx
468        .store(processor_idx, Ordering::Release);
469}
470
471#[cfg(any(test, doctest))]
472mod tests {
473    use std::sync::atomic::{AtomicBool, AtomicU32};
474
475    use super::*;
476
477    #[test]
478    fn test_scope() {
479        let mut pool = ParallelismBoundedThreadPool::new(&[None, None], 4, "worker");
480
481        let mut counter = 0u32;
482        for _ in 0..3 {
483            pool.scope(|_| {
484                counter += 1;
485            });
486        }
487
488        assert_eq!(counter, 3);
489    }
490
491    #[test]
492    fn test_run() {
493        let mut pool = ParallelismBoundedThreadPool::new(&[None, None], 4, "worker");
494
495        let counter = AtomicU32::new(0);
496        for _ in 0..3 {
497            pool.scope(|s| {
498                s.run(|_| {
499                    counter.fetch_add(1, Ordering::SeqCst);
500                });
501            });
502        }
503
504        assert_eq!(counter.load(Ordering::SeqCst), 12);
505    }
506
507    #[test]
508    fn test_pinning() {
509        let mut pool = ParallelismBoundedThreadPool::new(&[Some(0), Some(1)], 4, "worker");
510
511        let counter = AtomicU32::new(0);
512        for _ in 0..3 {
513            pool.scope(|s| {
514                s.run(|_| {
515                    counter.fetch_add(1, Ordering::SeqCst);
516                });
517            });
518        }
519
520        assert_eq!(counter.load(Ordering::SeqCst), 12);
521    }
522
523    #[test]
524    fn test_large_parallelism() {
525        let mut pool = ParallelismBoundedThreadPool::new(&vec![None; 100], 4, "worker");
526
527        let counter = AtomicU32::new(0);
528        for _ in 0..3 {
529            pool.scope(|s| {
530                s.run(|_| {
531                    counter.fetch_add(1, Ordering::SeqCst);
532                });
533            });
534        }
535
536        assert_eq!(counter.load(Ordering::SeqCst), 12);
537    }
538
539    #[test]
540    fn test_large_num_threads() {
541        let mut pool = ParallelismBoundedThreadPool::new(&[None, None], 100, "worker");
542
543        let counter = AtomicU32::new(0);
544        for _ in 0..3 {
545            pool.scope(|s| {
546                s.run(|_| {
547                    counter.fetch_add(1, Ordering::SeqCst);
548                });
549            });
550        }
551
552        assert_eq!(counter.load(Ordering::SeqCst), 300);
553    }
554
555    #[test]
556    fn test_scope_runner_order() {
557        let mut pool = ParallelismBoundedThreadPool::new(&[None], 1, "worker");
558
559        let flag = AtomicBool::new(false);
560        pool.scope(|s| {
561            s.run(|_| {
562                std::thread::sleep(std::time::Duration::from_millis(10));
563                flag.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
564                    .unwrap();
565            });
566            assert_eq!(flag.load(Ordering::SeqCst), false);
567        });
568
569        assert_eq!(flag.load(Ordering::SeqCst), true);
570    }
571
572    #[test]
573    fn test_non_aliasing_borrows() {
574        let mut pool = ParallelismBoundedThreadPool::new(&[None, None], 4, "worker");
575
576        let mut counter = 0;
577        pool.scope(|s| {
578            counter += 1;
579            s.run(|_| {
580                let _x = counter;
581            });
582        });
583
584        assert_eq!(counter, 1);
585    }
586
    // should not compile: "cannot assign to `counter` because it is borrowed"
    //
    // the task closure passed to `run` reads `counter`, and the scope closure then modifies
    // `counter` after the task has been started
    /// ```compile_fail
    /// # use shadow_rs::core::scheduler::pools::bounded::*;
    /// let mut pool = ParallelismBoundedThreadPool::new(&[None, None], 4, "worker");
    ///
    /// let mut counter = 0;
    /// pool.scope(|s| {
    ///     s.run(|_| {
    ///         let _x = counter;
    ///     });
    ///     counter += 1;
    /// });
    ///
    /// assert_eq!(counter, 1);
    /// ```
    fn _test_aliasing_borrows() {}
603
604    #[test]
605    #[should_panic]
606    fn test_panic_all() {
607        let mut pool = ParallelismBoundedThreadPool::new(&[None, None], 4, "worker");
608
609        pool.scope(|s| {
610            s.run(|t| {
611                // all threads panic
612                panic!("{}", t.thread_idx);
613            });
614        });
615    }
616
617    #[test]
618    #[should_panic]
619    fn test_panic_single() {
620        let mut pool = ParallelismBoundedThreadPool::new(&[None, None], 4, "worker");
621
622        pool.scope(|s| {
623            s.run(|t| {
624                // one thread panics
625                if t.thread_idx == 2 {
626                    panic!("{}", t.thread_idx);
627                }
628            });
629        });
630    }
631
    // should not compile: "`x` does not live long enough"
    //
    // the task closure tries to use a reference to the local `x` as a panic payload
    /// ```compile_fail
    /// # use shadow_rs::core::scheduler::pools::bounded::*;
    /// let mut pool = ParallelismBoundedThreadPool::new(&[None, None], 4, "worker");
    ///
    /// let x = 5;
    /// pool.scope(|s| {
    ///     s.run(|_| {
    ///         std::panic::panic_any(&x);
    ///     });
    /// });
    /// ```
    fn _test_panic_any() {}
645
    // should not compile: "closure may outlive the current function, but it borrows `x`, which is
    // owned by the current function"
    //
    // the task closure borrows `x`, which is a local variable of the scope closure
    /// ```compile_fail
    /// # use shadow_rs::core::scheduler::pools::bounded::*;
    /// let mut pool = ParallelismBoundedThreadPool::new(&[None, None], 4, "worker");
    ///
    /// pool.scope(|s| {
    ///     // 'x' will be dropped when the closure is dropped, but 's' lives longer than that
    ///     let x = 5;
    ///     s.run(|_| {
    ///         let _x = x;
    ///     });
    /// });
    /// ```
    fn _test_scope_lifetime() {}
661
662    #[test]
663    fn test_queues() {
664        let num_threads = 4;
665        let mut pool = ParallelismBoundedThreadPool::new(&[None, None], num_threads, "worker");
666
667        // a non-copy usize wrapper
668        struct Wrapper(usize);
669
670        let queues: Vec<_> = (0..num_threads)
671            .map(|_| crossbeam::queue::SegQueue::<Wrapper>::new())
672            .collect();
673
674        // queues[0] has Wrapper(0), queues[1] has Wrapper(1), etc
675        for (i, queue) in queues.iter().enumerate() {
676            queue.push(Wrapper(i));
677        }
678
679        let num_iters = 3;
680        for _ in 0..num_iters {
681            pool.scope(|s| {
682                s.run(|t| {
683                    // take item from queue n and push it to queue n+1
684                    let wrapper = queues[t.thread_idx].pop().unwrap();
685                    queues[(t.thread_idx + 1) % num_threads].push(wrapper);
686                });
687            });
688        }
689
690        for (i, queue) in queues.iter().enumerate() {
691            assert_eq!(
692                queue.pop().unwrap().0,
693                i.wrapping_sub(num_iters) % num_threads
694            );
695        }
696    }
697}