scheduler/pools/bounded.rs
// When comparing a loaded value that happens to be bool,
// assert_eq! reads better than assert!.
#![allow(clippy::bool_assert_comparison)]

use std::marker::PhantomData;
use std::ops::Deref;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};

use atomic_refcell::AtomicRefCell;

use crate::logical_processor::LogicalProcessors;
use crate::sync::count_down_latch::{LatchCounter, LatchWaiter, build_count_down_latch};
use crate::sync::thread_parking::{ThreadUnparker, ThreadUnparkerUnassigned};

// If making substantial changes to this scheduler, you should verify the compilation error message
// for each test at the end of this file to make sure that they correctly cause the expected
// compilation error. This work pool unsafely transmutes the task closure lifetime, and the
// commented tests are meant to make sure that the work pool does not allow unsound code to compile.
// Due to lifetime sub-typing/variance, rust will sometimes allow closures with shorter or longer
// lifetimes than we specify in the API, so the tests check to make sure the closures are invariant
// over the lifetime and that the usage is sound.

/// Context information provided to each task closure.
pub struct TaskData {
    pub thread_idx: usize,
    pub processor_idx: usize,
    pub cpu_id: Option<u32>,
}

/// A task that is run by the pool threads.
trait TaskFn: Fn(&TaskData) + Send + Sync {}
impl<T> TaskFn for T where T: Fn(&TaskData) + Send + Sync {}

/// A thread pool that runs a task on many threads. A task will run once on each thread. Each
/// logical processor will run threads sequentially, meaning that the thread pool's parallelism
/// depends on the number of processors, not the number of threads. Threads are assigned to logical
/// processors, which can be bound to operating system processors.
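///
/// A minimal usage sketch (mirroring `test_run` in the tests below): the pool is created once,
/// and each `scope`/`run` pair runs one closure on every thread.
///
/// ```
/// # use shadow_rs::core::scheduler::pools::bounded::*;
/// # use std::sync::atomic::{AtomicU32, Ordering};
/// // 4 worker threads shared across 2 unpinned logical processors
/// let mut pool = ParallelismBoundedThreadPool::new(&[None, None], 4, "worker");
///
/// let counter = AtomicU32::new(0);
/// pool.scope(|s| {
///     // the task closure runs once on each of the 4 threads
///     s.run(|_| {
///         counter.fetch_add(1, Ordering::SeqCst);
///     });
/// });
///
/// assert_eq!(counter.load(Ordering::SeqCst), 4);
/// pool.join();
/// ```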
pub struct ParallelismBoundedThreadPool {
    /// Handles for joining threads when they've exited.
    thread_handles: Vec<std::thread::JoinHandle<()>>,
    /// State shared between all threads.
    shared_state: Arc<SharedState>,
    /// The main thread uses this to wait for the threads to finish running the task.
    task_end_waiter: LatchWaiter,
}

/// State shared between the main thread and all pool threads.
pub struct SharedState {
    /// The task to run during the next round.
    task: AtomicRefCell<Option<Box<dyn TaskFn>>>,
    /// Has a thread panicked?
    has_thread_panicked: AtomicBool,
    /// The logical processors.
    logical_processors: AtomicRefCell<LogicalProcessors>,
    /// The threads which run on logical processors.
    threads: Vec<ThreadScheduling>,
}

/// Scheduling state for a thread.
pub struct ThreadScheduling {
    /// Used to unpark the thread when it has a new task.
    unparker: ThreadUnparker,
    /// The OS tid for this thread. This will have an invalid value when running under miri.
    #[cfg_attr(miri, allow(dead_code))]
    tid: nix::unistd::Pid,
    /// The logical processor index that this thread is assigned to.
    logical_processor_idx: AtomicUsize,
}

impl ParallelismBoundedThreadPool {
    /// A new work pool with logical processors that are pinned to the provided OS processors.
    /// Each logical processor is assigned many threads.
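    ///
    /// A construction sketch (marked `no_run` since pinning assumes that CPUs 0 and 1 exist on
    /// the host; compare `test_pinning` in the tests below):
    ///
    /// ```no_run
    /// # use shadow_rs::core::scheduler::pools::bounded::*;
    /// // 4 threads shared across 2 logical processors pinned to CPUs 0 and 1
    /// let pool = ParallelismBoundedThreadPool::new(&[Some(0), Some(1)], 4, "worker");
    /// assert_eq!(pool.num_processors(), 2);
    /// assert_eq!(pool.num_threads(), 4);
    /// ```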
    pub fn new(cpu_ids: &[Option<u32>], num_threads: usize, thread_name: &str) -> Self {
        // we don't need more logical processors than threads
        let cpu_ids = &cpu_ids[..std::cmp::min(cpu_ids.len(), num_threads)];

        let logical_processors = LogicalProcessors::new(cpu_ids, num_threads);

        let (task_end_counter, task_end_waiter) = build_count_down_latch();

        let mut thread_handles = Vec::new();
        let mut shared_state_senders = Vec::new();
        let mut tids = Vec::new();

        // start the threads
        for i in 0..num_threads {
            // the thread will send us the tid, then we'll later send the shared state to the thread
            let (tid_send, tid_recv) = crossbeam::channel::bounded(1);
            let (shared_state_send, shared_state_recv) = crossbeam::channel::bounded(1);

            let task_end_counter_clone = task_end_counter.clone();

            let handle = std::thread::Builder::new()
                .name(thread_name.to_string())
                .spawn(move || work_loop(i, tid_send, shared_state_recv, task_end_counter_clone))
                .unwrap();

            thread_handles.push(handle);
            shared_state_senders.push(shared_state_send);
            tids.push(tid_recv.recv().unwrap());
        }

        // build the scheduling data for the threads
        let thread_data: Vec<ThreadScheduling> = logical_processors
            .iter()
            .cycle()
            .zip(&tids)
            .zip(&thread_handles)
            .map(|((processor_idx, tid), handle)| ThreadScheduling {
                unparker: ThreadUnparkerUnassigned::new().assign(handle.thread().clone()),
                tid: *tid,
                logical_processor_idx: AtomicUsize::new(processor_idx),
            })
            .collect();

        // add each thread to its logical processor
        for (thread_idx, thread) in thread_data.iter().enumerate() {
            let logical_processor_idx = thread.logical_processor_idx.load(Ordering::Relaxed);
            logical_processors.add_worker(logical_processor_idx, thread_idx);
        }

        // state shared between all threads
        let shared_state = Arc::new(SharedState {
            task: AtomicRefCell::new(None),
            has_thread_panicked: AtomicBool::new(false),
            logical_processors: AtomicRefCell::new(logical_processors),
            threads: thread_data,
        });

        // send the shared state to each thread
        for s in shared_state_senders.into_iter() {
            s.send(Arc::clone(&shared_state)).unwrap();
        }

        Self {
            thread_handles,
            shared_state,
            task_end_waiter,
        }
    }

    /// The total number of logical processors.
    pub fn num_processors(&self) -> usize {
        self.shared_state.logical_processors.borrow().iter().len()
    }

    /// The total number of threads.
    pub fn num_threads(&self) -> usize {
        self.thread_handles.len()
    }

    /// Stop and join the threads.
    pub fn join(self) {
        // the drop handler will join the threads
    }

    fn join_internal(&mut self) {
        // a `None` indicates that the threads should end
        assert!(self.shared_state.task.borrow().is_none());

        // only check the thread join return value if no threads have yet panicked
        let check_for_errors = !self
            .shared_state
            .has_thread_panicked
            .load(Ordering::Relaxed);

        // send the sentinel task to all threads
        for thread in &self.shared_state.threads {
            thread.unparker.unpark();
        }

        for handle in self.thread_handles.drain(..) {
            let result = handle.join();
            if check_for_errors {
                result.expect("A thread panicked while stopping");
            }
        }
    }

    /// Create a new scope for the pool. The scope will ensure that any task run on the pool within
    /// this scope has completed before leaving the scope.
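    ///
    /// A sketch of what the scope allows (taken from `test_non_aliasing_borrows` in the tests
    /// below): the task may capture borrows of local data, since the scope guarantees the task
    /// has finished running before those borrows end.
    ///
    /// ```
    /// # use shadow_rs::core::scheduler::pools::bounded::*;
    /// let mut pool = ParallelismBoundedThreadPool::new(&[None, None], 4, "worker");
    ///
    /// let mut counter = 0;
    /// pool.scope(|s| {
    ///     // the scope closure may mutate `counter` before the task is started
    ///     counter += 1;
    ///     s.run(|_| {
    ///         // the task may then read `counter` from all threads
    ///         let _x = counter;
    ///     });
    /// });
    ///
    /// assert_eq!(counter, 1);
    /// ```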
    //
    // SAFETY: This works because:
    //
    // 1. WorkerScope<'scope> is covariant over 'scope.
    // 2. TaskRunner<'a, 'scope> is invariant over WorkerScope<'scope>, so TaskRunner<'a, 'scope>
    //    is invariant over 'scope.
    // 3. FnOnce(TaskRunner<'a, 'scope>) is contravariant over TaskRunner<'a, 'scope>, so
    //    FnOnce(TaskRunner<'a, 'scope>) is invariant over 'scope.
    //
    // This means that the provided scope closure cannot take a TaskRunner<'a, 'scope2> where
    // 'scope2 is shorter than 'scope, and therefore 'scope must be as long as this function call.
    //
    // If TaskRunner<'a, 'scope> were covariant over 'scope, then FnOnce(TaskRunner<'a, 'scope>)
    // would have been contravariant over 'scope. This would have allowed the user to provide a
    // scope closure that takes a TaskRunner<'a, 'scope2> where 'scope2 is shorter than 'scope.
    // Then when TaskRunner<'a, 'scope2>::run(...) was eventually called, the run closure could
    // capture data with a lifetime of only 'scope2, which would be a shorter lifetime than the
    // scope closure's lifetime of 'scope. Then any captured mutable references would be
    // accessible from both the run closure and the scope closure, leading to mutable aliasing.
    pub fn scope<'scope>(
        &'scope mut self,
        f: impl for<'a> FnOnce(TaskRunner<'a, 'scope>) + 'scope,
    ) {
        assert!(
            !self
                .shared_state
                .has_thread_panicked
                .load(Ordering::Relaxed),
            "Attempting to use a workpool that previously panicked"
        );

        // makes sure that the task is properly cleared even if 'f' panics
        let mut scope = WorkerScope::<'scope> {
            pool: self,
            _phantom: Default::default(),
        };

        let runner = TaskRunner { scope: &mut scope };

        f(runner);
    }
}

impl std::ops::Drop for ParallelismBoundedThreadPool {
    fn drop(&mut self) {
        self.join_internal();
    }
}

struct WorkerScope<'scope> {
    pool: &'scope mut ParallelismBoundedThreadPool,
    // when we are dropped, it's like dropping the task
    _phantom: PhantomData<Box<dyn TaskFn + 'scope>>,
}

impl std::ops::Drop for WorkerScope<'_> {
    fn drop(&mut self) {
        // if the task was set (if `TaskRunner::run` was called)
        if self.pool.shared_state.task.borrow().is_some() {
            // wait for the task to complete
            self.pool.task_end_waiter.wait();

            // clear the task
            *self.pool.shared_state.task.borrow_mut() = None;

            // we should have run every thread, so swap the logical processors' internal queues
            self.pool
                .shared_state
                .logical_processors
                .borrow_mut()
                .reset();

            // generally following https://docs.rs/rayon/latest/rayon/fn.scope.html#panics
            if self
                .pool
                .shared_state
                .has_thread_panicked
                .load(Ordering::Relaxed)
            {
                // we could store the thread's panic message and propagate it, but I don't think
                // that's worth handling
                panic!("A work thread panicked");
            }
        }
    }
}

/// Allows a single task to run per pool scope.
pub struct TaskRunner<'a, 'scope> {
    // SAFETY: Self must be invariant over 'scope, which is why we use &mut here. See the
    // documentation for scope() above for details.
    scope: &'a mut WorkerScope<'scope>,
}

impl<'scope> TaskRunner<'_, 'scope> {
    /// Run a task on the pool's threads.
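    ///
    /// A small sketch of the per-thread context passed to the task (following the pattern of the
    /// tests below; threads are numbered `0..num_threads` when they are spawned in `new()`):
    ///
    /// ```
    /// # use shadow_rs::core::scheduler::pools::bounded::*;
    /// # use std::sync::atomic::{AtomicUsize, Ordering};
    /// let mut pool = ParallelismBoundedThreadPool::new(&[None, None], 4, "worker");
    ///
    /// let thread_idx_sum = AtomicUsize::new(0);
    /// pool.scope(|s| {
    ///     s.run(|task_data| {
    ///         // each of the 4 threads sees its own thread index (0 through 3)
    ///         thread_idx_sum.fetch_add(task_data.thread_idx, Ordering::SeqCst);
    ///     });
    /// });
    /// assert_eq!(thread_idx_sum.load(Ordering::SeqCst), 6); // 0 + 1 + 2 + 3
    /// ```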
    // unfortunately we need to use `Fn(&TaskData) + Send + Sync` and not `TaskFn` here, otherwise
    // rust's type inference doesn't work nicely in the calling code
    pub fn run(self, f: impl Fn(&TaskData) + Send + Sync + 'scope) {
        let f = Box::new(f);

        // SAFETY: WorkerScope will drop this TaskFn before the end of 'scope
        let f = unsafe {
            std::mem::transmute::<Box<dyn TaskFn + 'scope>, Box<dyn TaskFn + 'static>>(f)
        };

        *self.scope.pool.shared_state.task.borrow_mut() = Some(f);

        let logical_processors = self.scope.pool.shared_state.logical_processors.borrow();

        // start the first thread for each logical processor
        for processor_idx in logical_processors.iter() {
            start_next_thread(
                processor_idx,
                &self.scope.pool.shared_state,
                &logical_processors,
            );
        }
    }
}

fn work_loop(
    thread_idx: usize,
    tid_send: crossbeam::channel::Sender<nix::unistd::Pid>,
    shared_state_recv: crossbeam::channel::Receiver<Arc<SharedState>>,
    mut end_counter: LatchCounter,
) {
    // we don't use `catch_unwind` here for two main reasons:
    //
    // 1. `catch_unwind` requires that the closure is `UnwindSafe`, which means that `TaskFn` also
    //    needs to be `UnwindSafe`. This is a big restriction on the types of tasks that we could
    //    run, since it requires that there's no interior mutability in the closure. rayon seems to
    //    get around this by wrapping the closure in `AssertUnwindSafe`, under the assumption that
    //    the panic will be propagated later with `resume_unwinding`, but this is a little more
    //    difficult to reason about compared to simply avoiding `catch_unwind` altogether.
    //    https://github.com/rayon-rs/rayon/blob/c571f8ffb4f74c8c09b4e1e6d9979b71b4414d07/rayon-core/src/unwind.rs#L9
    //
    // 2. There is a footgun with `catch_unwind` that could cause unexpected behaviour. If the
    //    closure called `panic_any()` with a type that has a Drop implementation, and that Drop
    //    implementation panics, it will cause a panic that is not caught by the `catch_unwind`,
    //    causing the thread to panic again with no chance to clean up properly. The work pool
    //    would then deadlock. Since we don't use `catch_unwind`, the thread will instead "panic
    //    when panicking" and abort, which is a more ideal outcome.
    //    https://github.com/rust-lang/rust/issues/86027

    // this will poison the workpool when it's dropped
    struct PoisonWhenDropped<'a>(&'a SharedState);

    impl std::ops::Drop for PoisonWhenDropped<'_> {
        fn drop(&mut self) {
            // if we panicked, then inform other threads that we panicked and allow them to exit
            // gracefully
            self.0.has_thread_panicked.store(true, Ordering::Relaxed);
        }
    }

    // this will start the next thread when it's dropped
    struct StartNextThreadOnDrop<'a> {
        shared_state: &'a SharedState,
        logical_processors: &'a LogicalProcessors,
        current_processor_idx: usize,
    }

    impl std::ops::Drop for StartNextThreadOnDrop<'_> {
        fn drop(&mut self) {
            start_next_thread(
                self.current_processor_idx,
                self.shared_state,
                self.logical_processors,
            );
        }
    }

    let tid = if cfg!(not(miri)) {
        nix::unistd::gettid()
    } else {
        // the sched_setaffinity() should be disabled under miri, so this should be fine
        nix::unistd::Pid::from_raw(-1)
    };

    // send this thread's tid to the main thread
    tid_send.send(tid).unwrap();

    // get the shared state
    let shared_state = shared_state_recv.recv().unwrap();
    let shared_state = shared_state.as_ref();

    let poison_when_dropped = PoisonWhenDropped(shared_state);

    let thread_data = &shared_state.threads[thread_idx];
    let thread_parker = thread_data.unparker.parker();

    loop {
        // wait for a new task
        thread_parker.park();

        // scope used to make sure we drop everything (including the task) before counting down
        {
            let logical_processors = &shared_state.logical_processors.borrow();

            // the logical processor for this thread may have been changed by the previous thread
            // if the thread was stolen from another logical processor
            let current_processor_idx = thread_data.logical_processor_idx.load(Ordering::Relaxed);

            // this will start the next thread even if the below task panics or we break from the
            // loop
            //
            // we must start the next thread before we count down, otherwise we'll have runtime
            // panics due to simultaneous exclusive and shared borrows of `logical_processors`
            let _start_next_thread_when_dropped = StartNextThreadOnDrop {
                shared_state,
                logical_processors,
                current_processor_idx,
            };

            // context information for the task
            let task_data = TaskData {
                thread_idx,
                processor_idx: current_processor_idx,
                cpu_id: logical_processors.cpu_id(current_processor_idx),
            };

            // run the task
            match shared_state.task.borrow().deref() {
                Some(task) => (task)(&task_data),
                None => {
                    // received the sentinel value
                    break;
                }
            };
        }

        // SAFETY: we do not hold any references/borrows to the task at this time
        end_counter.count_down();
    }

    // didn't panic, so forget the poison handler and return normally
    std::mem::forget(poison_when_dropped);
}

/// Choose the next thread to run on the logical processor, and then start it.
fn start_next_thread(
    processor_idx: usize,
    shared_state: &SharedState,
    logical_processors: &LogicalProcessors,
) {
    // if there is a thread to run on this logical processor, then start it
    if let Some((next_thread_idx, from_processor_idx)) =
        logical_processors.next_worker(processor_idx)
    {
        let next_thread = &shared_state.threads[next_thread_idx];

        debug_assert_eq!(
            from_processor_idx,
            next_thread.logical_processor_idx.load(Ordering::Relaxed)
        );

        // if the next thread is assigned to a different processor, move it to this processor
        if processor_idx != from_processor_idx {
            assign_to_processor(next_thread, processor_idx, logical_processors);
        }

        // start the thread
        next_thread.unparker.unpark();
    }
}

/// Assigns the thread to the logical processor.
fn assign_to_processor(
    thread: &ThreadScheduling,
    processor_idx: usize,
    logical_processors: &LogicalProcessors,
) {
    // set thread's affinity if the logical processor has a cpu ID
    if let Some(cpu_id) = logical_processors.cpu_id(processor_idx) {
        let mut cpus = nix::sched::CpuSet::new();
        cpus.set(cpu_id as usize).unwrap();

        // only set the affinity if not running in miri
        #[cfg(not(miri))]
        nix::sched::sched_setaffinity(thread.tid, &cpus).unwrap();
    }

    // set thread's processor
    thread
        .logical_processor_idx
        .store(processor_idx, Ordering::Release);
}

#[cfg(any(test, doctest))]
mod tests {
    use std::sync::atomic::{AtomicBool, AtomicU32};

    use super::*;

    #[test]
    fn test_scope() {
        let mut pool = ParallelismBoundedThreadPool::new(&[None, None], 4, "worker");

        let mut counter = 0u32;
        for _ in 0..3 {
            pool.scope(|_| {
                counter += 1;
            });
        }

        assert_eq!(counter, 3);
    }

    #[test]
    fn test_run() {
        let mut pool = ParallelismBoundedThreadPool::new(&[None, None], 4, "worker");

        let counter = AtomicU32::new(0);
        for _ in 0..3 {
            pool.scope(|s| {
                s.run(|_| {
                    counter.fetch_add(1, Ordering::SeqCst);
                });
            });
        }

        assert_eq!(counter.load(Ordering::SeqCst), 12);
    }

    #[test]
    fn test_pinning() {
        let mut pool = ParallelismBoundedThreadPool::new(&[Some(0), Some(1)], 4, "worker");

        let counter = AtomicU32::new(0);
        for _ in 0..3 {
            pool.scope(|s| {
                s.run(|_| {
                    counter.fetch_add(1, Ordering::SeqCst);
                });
            });
        }

        assert_eq!(counter.load(Ordering::SeqCst), 12);
    }

    #[test]
    fn test_large_parallelism() {
        let mut pool = ParallelismBoundedThreadPool::new(&vec![None; 100], 4, "worker");

        let counter = AtomicU32::new(0);
        for _ in 0..3 {
            pool.scope(|s| {
                s.run(|_| {
                    counter.fetch_add(1, Ordering::SeqCst);
                });
            });
        }

        assert_eq!(counter.load(Ordering::SeqCst), 12);
    }

    #[test]
    fn test_large_num_threads() {
        let mut pool = ParallelismBoundedThreadPool::new(&[None, None], 100, "worker");

        let counter = AtomicU32::new(0);
        for _ in 0..3 {
            pool.scope(|s| {
                s.run(|_| {
                    counter.fetch_add(1, Ordering::SeqCst);
                });
            });
        }

        assert_eq!(counter.load(Ordering::SeqCst), 300);
    }

    #[test]
    fn test_scope_runner_order() {
        let mut pool = ParallelismBoundedThreadPool::new(&[None], 1, "worker");

        let flag = AtomicBool::new(false);
        pool.scope(|s| {
            s.run(|_| {
                std::thread::sleep(std::time::Duration::from_millis(10));
                flag.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
                    .unwrap();
            });
            assert_eq!(flag.load(Ordering::SeqCst), false);
        });

        assert_eq!(flag.load(Ordering::SeqCst), true);
    }

    #[test]
    fn test_non_aliasing_borrows() {
        let mut pool = ParallelismBoundedThreadPool::new(&[None, None], 4, "worker");

        let mut counter = 0;
        pool.scope(|s| {
            counter += 1;
            s.run(|_| {
                let _x = counter;
            });
        });

        assert_eq!(counter, 1);
    }

    // should not compile: "cannot assign to `counter` because it is borrowed"
    /// ```compile_fail
    /// # use shadow_rs::core::scheduler::pools::bounded::*;
    /// let mut pool = ParallelismBoundedThreadPool::new(&[None, None], 4, "worker");
    ///
    /// let mut counter = 0;
    /// pool.scope(|s| {
    ///     s.run(|_| {
    ///         let _x = counter;
    ///     });
    ///     counter += 1;
    /// });
    ///
    /// assert_eq!(counter, 1);
    /// ```
    fn _test_aliasing_borrows() {}

    #[test]
    #[should_panic]
    fn test_panic_all() {
        let mut pool = ParallelismBoundedThreadPool::new(&[None, None], 4, "worker");

        pool.scope(|s| {
            s.run(|t| {
                // all threads panic
                panic!("{}", t.thread_idx);
            });
        });
    }

    #[test]
    #[should_panic]
    fn test_panic_single() {
        let mut pool = ParallelismBoundedThreadPool::new(&[None, None], 4, "worker");

        pool.scope(|s| {
            s.run(|t| {
                // only one thread panics
                if t.thread_idx == 2 {
                    panic!("{}", t.thread_idx);
                }
            });
        });
    }

    // should not compile: "`x` does not live long enough"
    /// ```compile_fail
    /// # use shadow_rs::core::scheduler::pools::bounded::*;
    /// let mut pool = ParallelismBoundedThreadPool::new(&[None, None], 4, "worker");
    ///
    /// let x = 5;
    /// pool.scope(|s| {
    ///     s.run(|_| {
    ///         std::panic::panic_any(&x);
    ///     });
    /// });
    /// ```
    fn _test_panic_any() {}

    // should not compile: "closure may outlive the current function, but it borrows `x`, which is
    // owned by the current function"
    /// ```compile_fail
    /// # use shadow_rs::core::scheduler::pools::bounded::*;
    /// let mut pool = ParallelismBoundedThreadPool::new(&[None, None], 4, "worker");
    ///
    /// pool.scope(|s| {
    ///     // `x` will be dropped when the closure is dropped, but `s` lives longer than that
    ///     let x = 5;
    ///     s.run(|_| {
    ///         let _x = x;
    ///     });
    /// });
    /// ```
    fn _test_scope_lifetime() {}

    #[test]
    fn test_queues() {
        let num_threads = 4;
        let mut pool = ParallelismBoundedThreadPool::new(&[None, None], num_threads, "worker");

        // a non-copy usize wrapper
        struct Wrapper(usize);

        let queues: Vec<_> = (0..num_threads)
            .map(|_| crossbeam::queue::SegQueue::<Wrapper>::new())
            .collect();

        // queues[0] starts with Wrapper(0), queues[1] with Wrapper(1), etc.
        for (i, queue) in queues.iter().enumerate() {
            queue.push(Wrapper(i));
        }

        let num_iters = 3;
        for _ in 0..num_iters {
            pool.scope(|s| {
                s.run(|t| {
                    // take an item from queue n and push it to queue n+1
                    let wrapper = queues[t.thread_idx].pop().unwrap();
                    queues[(t.thread_idx + 1) % num_threads].push(wrapper);
                });
            });
        }

        // after `num_iters` rotations, queue i should hold the item that started in queue
        // (i - num_iters) mod num_threads; the wrapping arithmetic below is equivalent because
        // num_threads is a power of two
        for (i, queue) in queues.iter().enumerate() {
            assert_eq!(
                queue.pop().unwrap().0,
                i.wrapping_sub(num_iters) % num_threads
            );
        }
    }
}
697}