scheduler/pools/unbounded.rs
// When comparing a loaded value that happens to be bool,
// assert_eq! reads better than assert!.
#![allow(clippy::bool_assert_comparison)]

use std::marker::PhantomData;
use std::ops::Deref;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};

use atomic_refcell::AtomicRefCell;

use crate::sync::count_down_latch::{self, build_count_down_latch};
use crate::sync::simple_latch;

// If making substantial changes to this scheduler, you should verify the compilation error message
// for each test at the end of this file to make sure that they correctly cause the expected
// compilation error. This work pool unsafely transmutes the task closure lifetime, and the
// commented tests are meant to make sure that the work pool does not allow unsound code to compile.
// Due to lifetime sub-typing/variance, Rust will sometimes allow closures with shorter or longer
// lifetimes than we specify in the API, so the tests check to make sure the closures are invariant
// over the lifetime and that the usage is sound.

/// A task that is run by the pool threads. The task is passed the index of the thread that it is
/// running on.
pub trait TaskFn: Fn(usize) + Send + Sync {}
impl<T> TaskFn for T where T: Fn(usize) + Send + Sync {}

/// A thread pool that runs a task on many threads. A task will run once on each thread.
pub struct UnboundedThreadPool {
    /// Handles for joining threads when they've exited.
    thread_handles: Vec<std::thread::JoinHandle<()>>,
    /// State shared between all threads.
    shared_state: Arc<SharedState>,
    /// A latch that is opened when the task is set. Indicates to the threads that they should
    /// start running the task.
    task_start_latch: simple_latch::Latch,
    /// The main thread uses this to wait for the threads to finish running the task.
    task_end_waiter: count_down_latch::LatchWaiter,
}

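/// State shared between the main thread and the pool's worker threads.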
pub struct SharedState {
    /// The task to run during the next round.
    task: AtomicRefCell<Option<Box<dyn TaskFn>>>,
    /// Has a thread panicked?
    has_thread_panicked: AtomicBool,
}

impl UnboundedThreadPool {
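    /// Create a new pool of `num_threads` threads, each named `thread_name`. If `yield_spin` is
    /// true, the threads spin rather than block while waiting for a new task, which may improve
    /// performance under some conditions (see the comment on the start waiter below).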
    pub fn new(num_threads: usize, thread_name: &str, yield_spin: bool) -> Self {
        let shared_state = Arc::new(SharedState {
            task: AtomicRefCell::new(None),
            has_thread_panicked: AtomicBool::new(false),
        });

        let (task_end_counter, task_end_waiter) = build_count_down_latch();
        let mut task_start_latch = simple_latch::Latch::new();

        let mut thread_handles = Vec::new();

        for i in 0..num_threads {
            let shared_state_clone = Arc::clone(&shared_state);

            // enabling spinning on the threads may improve performance under some conditions
            // (see https://github.com/shadow/shadow/issues/2877)
            let task_start_waiter = task_start_latch.waiter(yield_spin);

            let task_end_counter_clone = task_end_counter.clone();

            let handle = std::thread::Builder::new()
                .name(thread_name.to_string())
                .spawn(move || {
                    work_loop(
                        i,
                        shared_state_clone,
                        task_start_waiter,
                        task_end_counter_clone,
                    )
                })
                .unwrap();

            thread_handles.push(handle);
        }

        Self {
            thread_handles,
            shared_state,
            task_start_latch,
            task_end_waiter,
        }
    }

    /// Stop and join the threads.
    pub fn join(self) {
        // the drop handler will join the threads
    }

    fn join_internal(&mut self) {
        // a `None` indicates that the threads should end
        assert!(self.shared_state.task.borrow().is_none());

        // only check the thread join return value if no threads have yet panicked
        let check_for_errors = !self
            .shared_state
            .has_thread_panicked
            .load(Ordering::Relaxed);

        // start the threads
        self.task_start_latch.open();

        for handle in self.thread_handles.drain(..) {
            let result = handle.join();
            if check_for_errors {
                result.expect("A thread panicked while stopping");
            }
        }
    }

    /// Create a new scope for the pool. The scope will ensure that any task run on the pool within
    /// this scope has completed before leaving the scope.
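    ///
    /// A minimal usage sketch (mirroring `test_run` at the end of this file):
    ///
    /// ```
    /// # use shadow_rs::core::scheduler::pools::unbounded::*;
    /// # use std::sync::atomic::{AtomicU32, Ordering};
    /// let mut pool = UnboundedThreadPool::new(4, "worker", false);
    ///
    /// let counter = AtomicU32::new(0);
    /// pool.scope(|s| {
    ///     s.run(|_| {
    ///         counter.fetch_add(1, Ordering::SeqCst);
    ///     });
    /// });
    ///
    /// // the task ran once on each of the pool's 4 threads
    /// assert_eq!(counter.load(Ordering::SeqCst), 4);
    /// ```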
    //
    // SAFETY: This works because:
    //
    // 1. WorkerScope<'scope> is covariant over 'scope.
    // 2. TaskRunner<'a, 'scope> is invariant over WorkerScope<'scope>, so TaskRunner<'a, 'scope>
    //    is invariant over 'scope.
    // 3. FnOnce(TaskRunner<'a, 'scope>) is contravariant over TaskRunner<'a, 'scope>, so
    //    FnOnce(TaskRunner<'a, 'scope>) is invariant over 'scope.
    //
    // This means that the provided scope closure cannot take a TaskRunner<'a, 'scope2> where
    // 'scope2 is shorter than 'scope, and therefore 'scope must be as long as this function call.
    //
    // If TaskRunner<'a, 'scope> were covariant over 'scope, then FnOnce(TaskRunner<'a, 'scope>)
    // would have been contravariant over 'scope. This would have allowed the user to provide a
    // scope closure that could take a TaskRunner<'a, 'scope2> where 'scope2 is shorter than
    // 'scope. Then when TaskRunner<'a, 'scope2>::run(...) was eventually called, the run closure
    // would capture data with a lifetime of only 'scope2, which would be a shorter lifetime than
    // the scope closure's lifetime of 'scope. Then any captured mutable references would be
    // accessible from both the run closure and the scope closure, leading to mutable aliasing.
    pub fn scope<'scope>(
        &'scope mut self,
        f: impl for<'a> FnOnce(TaskRunner<'a, 'scope>) + 'scope,
    ) {
        assert!(
            !self
                .shared_state
                .has_thread_panicked
                .load(Ordering::Relaxed),
            "Attempting to use a workpool that previously panicked"
        );

        // makes sure that the task is properly cleared even if 'f' panics
        let mut scope = WorkerScope::<'scope> {
            pool: self,
            _phantom: Default::default(),
        };

        let runner = TaskRunner { scope: &mut scope };

        f(runner);
    }
}

impl std::ops::Drop for UnboundedThreadPool {
    fn drop(&mut self) {
        self.join_internal();
    }
}

struct WorkerScope<'scope> {
    pool: &'scope mut UnboundedThreadPool,
    // when we are dropped, it's like dropping the task
    _phantom: PhantomData<Box<dyn TaskFn + 'scope>>,
}

impl std::ops::Drop for WorkerScope<'_> {
    fn drop(&mut self) {
        // if the task was set (i.e. if `TaskRunner::run` was called)
        if self.pool.shared_state.task.borrow().is_some() {
            // wait for the task to complete
            self.pool.task_end_waiter.wait();

            // clear the task
            *self.pool.shared_state.task.borrow_mut() = None;

            // generally following https://docs.rs/rayon/latest/rayon/fn.scope.html#panics
            if self
                .pool
                .shared_state
                .has_thread_panicked
                .load(Ordering::Relaxed)
            {
                // we could store the thread's panic message and propagate it, but I don't think
                // that's worth handling
                panic!("A work thread panicked");
            }
        }
    }
}

/// Allows a single task to run per pool scope.
pub struct TaskRunner<'a, 'scope> {
    // SAFETY: Self must be invariant over 'scope, which is why we use &mut here. See the
    // documentation for scope() above for details.
    scope: &'a mut WorkerScope<'scope>,
}

impl<'scope> TaskRunner<'_, 'scope> {
    /// Run a task on the pool's threads.
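    ///
    /// The task is passed the index of the thread it runs on (in `0..num_threads`). As a sketch
    /// (following the pattern of `test_queues` at the end of this file), the index can be used to
    /// give each thread its own unit of work:
    ///
    /// ```
    /// # use shadow_rs::core::scheduler::pools::unbounded::*;
    /// # use std::sync::atomic::{AtomicU32, Ordering};
    /// let num_threads = 2;
    /// let mut pool = UnboundedThreadPool::new(num_threads, "worker", false);
    ///
    /// let sums: Vec<AtomicU32> = (0..num_threads).map(|_| AtomicU32::new(0)).collect();
    /// pool.scope(|s| {
    ///     s.run(|thread_index| {
    ///         // each thread updates only its own slot
    ///         sums[thread_index].fetch_add(1, Ordering::SeqCst);
    ///     });
    /// });
    ///
    /// assert!(sums.iter().all(|n| n.load(Ordering::SeqCst) == 1));
    /// ```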
    pub fn run(self, f: impl TaskFn + 'scope) {
        let f = Box::new(f);

        // SAFETY: WorkerScope will drop this TaskFn before the end of 'scope
        let f = unsafe {
            std::mem::transmute::<Box<dyn TaskFn + 'scope>, Box<dyn TaskFn + 'static>>(f)
        };

        *self.scope.pool.shared_state.task.borrow_mut() = Some(f);

        // we've set the task, so start the threads
        self.scope.pool.task_start_latch.open();
    }
}

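/// The loop run by each worker thread: wait for a task to be set, run it, count down on the end
/// latch, and repeat until a `None` task (the sentinel value) is observed.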
fn work_loop(
    thread_index: usize,
    shared_state: Arc<SharedState>,
    mut start_waiter: simple_latch::LatchWaiter,
    mut end_counter: count_down_latch::LatchCounter,
) {
    // we don't use `catch_unwind` here for two main reasons:
    //
    // 1. `catch_unwind` requires that the closure is `UnwindSafe`, which means that `TaskFn` also
    //    needs to be `UnwindSafe`. This is a big restriction on the types of tasks that we could
    //    run, since it requires that there's no interior mutability in the closure. rayon seems to
    //    get around this by wrapping the closure in `AssertUnwindSafe`, under the assumption that
    //    the panic will be propagated later with `resume_unwinding`, but this is a little more
    //    difficult to reason about compared to simply avoiding `catch_unwind` altogether.
    //    https://github.com/rayon-rs/rayon/blob/c571f8ffb4f74c8c09b4e1e6d9979b71b4414d07/rayon-core/src/unwind.rs#L9
    //
    // 2. There is a footgun with `catch_unwind` that could cause unexpected behaviour. If the
    //    closure calls `panic_any()` with a type that has a `Drop` implementation, and that `Drop`
    //    implementation panics, it will cause a panic that is not caught by the `catch_unwind`,
    //    causing the thread to panic again with no chance to clean up properly. The work pool
    //    would then deadlock. Since we don't use `catch_unwind`, the thread will instead "panic
    //    when panicking" and abort, which is the better outcome.
    //    https://github.com/rust-lang/rust/issues/86027

    // this will poison the workpool when it's dropped
    struct PoisonWhenDropped<'a>(&'a SharedState);

    impl std::ops::Drop for PoisonWhenDropped<'_> {
        fn drop(&mut self) {
            // if we panicked, then inform other threads that we panicked and allow them to exit
            // gracefully
            self.0.has_thread_panicked.store(true, Ordering::Relaxed);
        }
    }

    let shared_state = shared_state.as_ref();
    let poison_when_dropped = PoisonWhenDropped(shared_state);

    loop {
        // wait for a new task
        start_waiter.wait();

        // scope used to make sure we drop our borrow of the task before counting down
        {
            // run the task
            match shared_state.task.borrow().deref() {
                Some(task) => (task)(thread_index),
                None => {
                    // received the sentinel value
                    break;
                }
            };
        }

        // SAFETY: we do not hold any references/borrows to the task at this time
        end_counter.count_down();
    }

    // didn't panic, so forget the poison handler and return normally
    std::mem::forget(poison_when_dropped);
}

#[cfg(any(test, doctest))]
mod tests {
    use std::sync::atomic::{AtomicBool, AtomicU32};

    use super::*;

    #[test]
    fn test_scope() {
        let mut pool = UnboundedThreadPool::new(4, "worker", false);

        let mut counter = 0u32;
        for _ in 0..3 {
            pool.scope(|_| {
                counter += 1;
            });
        }

        assert_eq!(counter, 3);
    }

    #[test]
    fn test_run() {
        let mut pool = UnboundedThreadPool::new(4, "worker", false);

        let counter = AtomicU32::new(0);
        for _ in 0..3 {
            pool.scope(|s| {
                s.run(|_| {
                    counter.fetch_add(1, Ordering::SeqCst);
                });
            });
        }

        assert_eq!(counter.load(Ordering::SeqCst), 12);
    }

    #[test]
    fn test_large_num_threads() {
        let mut pool = UnboundedThreadPool::new(100, "worker", false);

        let counter = AtomicU32::new(0);
        for _ in 0..3 {
            pool.scope(|s| {
                s.run(|_| {
                    counter.fetch_add(1, Ordering::SeqCst);
                });
            });
        }

        assert_eq!(counter.load(Ordering::SeqCst), 300);
    }

    #[test]
    fn test_scope_runner_order() {
        let mut pool = UnboundedThreadPool::new(1, "worker", false);

        let flag = AtomicBool::new(false);
        pool.scope(|s| {
            s.run(|_| {
                std::thread::sleep(std::time::Duration::from_millis(10));
                flag.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
                    .unwrap();
            });
            assert_eq!(flag.load(Ordering::SeqCst), false);
        });

        assert_eq!(flag.load(Ordering::SeqCst), true);
    }

    #[test]
    fn test_non_aliasing_borrows() {
        let mut pool = UnboundedThreadPool::new(4, "worker", false);

        let mut counter = 0;
        pool.scope(|s| {
            counter += 1;
            s.run(|_| {
                let _x = counter;
            });
        });

        assert_eq!(counter, 1);
    }

    // should not compile: "cannot assign to `counter` because it is borrowed"
    /// ```compile_fail
    /// # use shadow_rs::core::scheduler::pools::unbounded::*;
    /// let x = 5;
    /// let mut pool = UnboundedThreadPool::new(4, "worker", false);
    ///
    /// let mut counter = 0;
    /// pool.scope(|s| {
    ///     s.run(|_| {
    ///         let _x = counter;
    ///     });
    ///     counter += 1;
    /// });
    ///
    /// assert_eq!(counter, 1);
    /// ```
    fn _test_aliasing_borrows() {}

    #[test]
    #[should_panic]
    fn test_panic_all() {
        let mut pool = UnboundedThreadPool::new(4, "worker", false);

        pool.scope(|s| {
            s.run(|i| {
                // all threads panic
                panic!("{}", i);
            });
        });
    }

    #[test]
    #[should_panic]
    fn test_panic_single() {
        let mut pool = UnboundedThreadPool::new(4, "worker", false);

        pool.scope(|s| {
            s.run(|i| {
                // one thread panics
                if i == 2 {
                    panic!("{}", i);
                }
            });
        });
    }

    // should not compile: "`x` does not live long enough"
    /// ```compile_fail
    /// # use shadow_rs::core::scheduler::pools::unbounded::*;
    /// let mut pool = UnboundedThreadPool::new(4, "worker", false);
    ///
    /// let x = 5;
    /// pool.scope(|s| {
    ///     s.run(|_| {
    ///         std::panic::panic_any(&x);
    ///     });
    /// });
    /// ```
    fn _test_panic_any() {}

    // should not compile: "closure may outlive the current function, but it borrows `x`, which is
    // owned by the current function"
    /// ```compile_fail
    /// # use shadow_rs::core::scheduler::pools::unbounded::*;
    /// let mut pool = UnboundedThreadPool::new(4, "worker", false);
    ///
    /// pool.scope(|s| {
    ///     // 'x' will be dropped when the closure is dropped, but 's' lives longer than that
    ///     let x = 5;
    ///     s.run(|_| {
    ///         let _x = x;
    ///     });
    /// });
    /// ```
    fn _test_scope_lifetime() {}

    #[test]
    fn test_queues() {
        let num_threads = 4;
        let mut pool = UnboundedThreadPool::new(num_threads, "worker", false);

        // a non-copy usize wrapper
        struct Wrapper(usize);

        let queues: Vec<_> = (0..num_threads)
            .map(|_| crossbeam::queue::SegQueue::<Wrapper>::new())
            .collect();

        // queues[0] has Wrapper(0), queues[1] has Wrapper(1), etc
        for (i, queue) in queues.iter().enumerate() {
            queue.push(Wrapper(i));
        }

        let num_iters = 3;
        for _ in 0..num_iters {
            pool.scope(|s| {
                s.run(|i: usize| {
                    // take item from queue n and push it to queue n+1
                    let wrapper = queues[i].pop().unwrap();
                    queues[(i + 1) % num_threads].push(wrapper);
                });
            });
        }

        for (i, queue) in queues.iter().enumerate() {
            assert_eq!(
                queue.pop().unwrap().0,
                i.wrapping_sub(num_iters) % num_threads
            );
        }
    }
}