// scheduler/thread_per_core.rs

//! A thread-per-core host scheduler.

// unsafe code should be isolated to the thread pool
#![forbid(unsafe_code)]

use std::fmt::Debug;

use crossbeam::queue::ArrayQueue;

use crate::CORE_AFFINITY;
use crate::pools::unbounded::{TaskRunner, UnboundedThreadPool};

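/// A trait for types that can be scheduled as hosts. The blanket impl below
/// makes any `Debug + Send` type usable as a host.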
pub trait Host: Debug + Send {}
impl<T> Host for T where T: Debug + Send {}

/// A host scheduler.
pub struct ThreadPerCoreSched<HostType: Host> {
    pool: UnboundedThreadPool,
    num_threads: usize,
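    // Double-buffered host queues: a pass pops hosts from `thread_hosts` and
    // pushes them into `thread_hosts_processed`; `hosts_need_swap` records
    // that the two must be swapped back before the next pass.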
    thread_hosts: Vec<ArrayQueue<HostType>>,
    thread_hosts_processed: Vec<ArrayQueue<HostType>>,
    hosts_need_swap: bool,
}

impl<HostType: Host> ThreadPerCoreSched<HostType> {
    /// Creates a new host scheduler with threads pinned to the provided OS processors. Each
    /// thread is assigned multiple hosts, and threads may steal hosts from other threads. The
    /// number of threads created equals the length of `cpu_ids`.
    pub fn new<T>(cpu_ids: &[Option<u32>], hosts: T, yield_spin: bool) -> Self
    where
        T: IntoIterator<Item = HostType, IntoIter: ExactSizeIterator>,
    {
        let hosts = hosts.into_iter();

        let num_threads = cpu_ids.len();
        let mut pool = UnboundedThreadPool::new(num_threads, "shadow-worker", yield_spin);

        // set the affinity of each thread
        pool.scope(|s| {
            s.run(|i| {
                let cpu_id = cpu_ids[i];

                if let Some(cpu_id) = cpu_id {
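                    // a pid of 0 tells sched_setaffinity(2) to pin the calling
                    // thread, which here is worker thread `i`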
                    let mut cpus = nix::sched::CpuSet::new();
                    cpus.set(cpu_id as usize).unwrap();
                    nix::sched::sched_setaffinity(nix::unistd::Pid::from_raw(0), &cpus).unwrap();

                    // update the thread-local core affinity
                    CORE_AFFINITY.with(|x| x.set(Some(cpu_id)));
                }
            });
        });

        // each thread gets two fixed-size queues with enough capacity to store every host
        let thread_hosts: Vec<_> = (0..num_threads)
            .map(|_| ArrayQueue::new(hosts.len()))
            .collect();
        let thread_hosts_2: Vec<_> = (0..num_threads)
            .map(|_| ArrayQueue::new(hosts.len()))
            .collect();

        // assign hosts to threads in a round-robin manner
        for (thread_queue, host) in thread_hosts.iter().cycle().zip(hosts) {
            thread_queue.push(host).unwrap();
        }
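        // e.g. with 2 threads and 5 hosts, thread 0's queue receives hosts
        // 0, 2, and 4, while thread 1's queue receives hosts 1 and 3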

        Self {
            pool,
            num_threads,
            thread_hosts,
            thread_hosts_processed: thread_hosts_2,
            hosts_need_swap: false,
        }
    }

    /// See [`crate::Scheduler::parallelism`].
    pub fn parallelism(&self) -> usize {
        self.num_threads
    }

    /// See [`crate::Scheduler::scope`].
    pub fn scope<'scope>(
        &'scope mut self,
        f: impl for<'a, 'b> FnOnce(SchedulerScope<'a, 'b, 'scope, HostType>) + 'scope,
    ) {
        // we can't swap after the below `pool.scope()` due to lifetime restrictions, so we need to
        // do it before instead
        if self.hosts_need_swap {
            debug_assert!(self.thread_hosts.iter().all(|queue| queue.is_empty()));

            std::mem::swap(&mut self.thread_hosts, &mut self.thread_hosts_processed);
            self.hosts_need_swap = false;
        }

        // data/references that we'll pass to the scope
        let thread_hosts = &self.thread_hosts;
        let thread_hosts_processed = &self.thread_hosts_processed;
        let hosts_need_swap = &mut self.hosts_need_swap;

        // we cannot access `self` after calling `pool.scope()` since `SchedulerScope` has a
        // lifetime of `'scope` (which at minimum spans the entire current function)

        self.pool.scope(move |s| {
            let sched_scope = SchedulerScope {
                thread_hosts,
                thread_hosts_processed,
                hosts_need_swap,
                runner: s,
            };

            (f)(sched_scope);
        });
    }

    /// See [`crate::Scheduler::join`].
    pub fn join(self) {
        self.pool.join();
    }
}

/// A wrapper around the work pool's scoped runner.
pub struct SchedulerScope<'sched, 'pool, 'scope, HostType: Host>
where
    'sched: 'scope,
{
    thread_hosts: &'sched Vec<ArrayQueue<HostType>>,
    thread_hosts_processed: &'sched Vec<ArrayQueue<HostType>>,
    hosts_need_swap: &'sched mut bool,
    runner: TaskRunner<'pool, 'scope>,
}

// there are multiple named lifetimes, so let's just be explicit about them rather than hide them
#[allow(clippy::needless_lifetimes)]
impl<'sched, 'pool, 'scope, HostType: Host> SchedulerScope<'sched, 'pool, 'scope, HostType> {
    /// See [`crate::SchedulerScope::run`].
    pub fn run(self, f: impl Fn(usize) + Sync + Send + 'scope) {
        self.runner.run(f);
    }

    /// See [`crate::SchedulerScope::run_with_hosts`].
    pub fn run_with_hosts(
        self,
        f: impl Fn(usize, &mut HostIter<'_, HostType>) + Send + Sync + 'scope,
    ) {
        self.runner.run(move |i| {
            let mut host_iter = HostIter {
                thread_hosts_from: self.thread_hosts,
                thread_hosts_to: &self.thread_hosts_processed[i],
                this_thread_index: i,
            };

            f(i, &mut host_iter);
        });

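        // all hosts have been moved into the "processed" queues; the next
        // `scope()` call will swap them back into `thread_hosts`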
        *self.hosts_need_swap = true;
    }

    /// See [`crate::SchedulerScope::run_with_data`].
    pub fn run_with_data<T>(
        self,
        data: &'scope [T],
        f: impl Fn(usize, &mut HostIter<'_, HostType>, &T) + Send + Sync + 'scope,
    ) where
        T: Sync,
    {
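        // each worker thread reads the `data` element at its own thread index,
        // so `data` must contain at least `parallelism()` elements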
        self.runner.run(move |i| {
            let this_elem = &data[i];

            let mut host_iter = HostIter {
                thread_hosts_from: self.thread_hosts,
                thread_hosts_to: &self.thread_hosts_processed[i],
                this_thread_index: i,
            };

            f(i, &mut host_iter, this_elem);
        });

        *self.hosts_need_swap = true;
    }
}

/// Supports iterating over all hosts assigned to this thread. For this thread-per-core scheduler,
/// the iterator may steal hosts from other threads.
pub struct HostIter<'a, HostType: Host> {
    /// Queues to take hosts from.
    thread_hosts_from: &'a [ArrayQueue<HostType>],
    /// The queue to add hosts to when done with them.
    thread_hosts_to: &'a ArrayQueue<HostType>,
    /// The index of this thread. Its queue is the first queue in `thread_hosts_from` that we take
    /// hosts from.
    this_thread_index: usize,
}

impl<HostType: Host> HostIter<'_, HostType> {
    /// See [`crate::HostIter::for_each`].
    pub fn for_each<F>(&mut self, mut f: F)
    where
        F: FnMut(HostType) -> HostType,
    {
        for from_queue in self
            .thread_hosts_from
            .iter()
            .cycle()
            // start with this thread's own queue, then steal from the others in order
            .skip(self.this_thread_index)
            .take(self.thread_hosts_from.len())
        {
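            // `ArrayQueue` is a concurrent bounded queue, so multiple threads may
            // pop from the same queue at once; each host is taken exactly once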
            while let Some(host) = from_queue.pop() {
                self.thread_hosts_to.push(f(host)).unwrap();
            }
        }
    }
}

#[cfg(any(test, doctest))]
mod tests {
    use std::sync::atomic::{AtomicU32, Ordering};

    use super::*;

    #[derive(Debug)]
    struct TestHost {}

    #[test]
    fn test_parallelism() {
        let hosts = [(); 5].map(|_| TestHost {});
        let sched: ThreadPerCoreSched<TestHost> =
            ThreadPerCoreSched::new(&[None, None], hosts, false);

        assert_eq!(sched.parallelism(), 2);

        sched.join();
    }

    #[test]
    fn test_no_join() {
        let hosts = [(); 5].map(|_| TestHost {});
        let _sched: ThreadPerCoreSched<TestHost> =
            ThreadPerCoreSched::new(&[None, None], hosts, false);
    }

    #[test]
    #[should_panic]
    fn test_panic() {
        let hosts = [(); 5].map(|_| TestHost {});
        let mut sched: ThreadPerCoreSched<TestHost> =
            ThreadPerCoreSched::new(&[None, None], hosts, false);

        sched.scope(|s| {
            s.run(|x| {
                if x == 1 {
                    panic!();
                }
            });
        });
    }

    #[test]
    fn test_run() {
        let hosts = [(); 5].map(|_| TestHost {});
        let mut sched: ThreadPerCoreSched<TestHost> =
            ThreadPerCoreSched::new(&[None, None], hosts, false);

        let counter = AtomicU32::new(0);

        for _ in 0..3 {
            sched.scope(|s| {
                s.run(|_| {
                    counter.fetch_add(1, Ordering::SeqCst);
                });
            });
        }

        assert_eq!(counter.load(Ordering::SeqCst), 2 * 3);

        sched.join();
    }

    #[test]
    fn test_run_with_hosts() {
        let hosts = [(); 5].map(|_| TestHost {});
        let mut sched: ThreadPerCoreSched<TestHost> =
            ThreadPerCoreSched::new(&[None, None], hosts, false);

        let counter = AtomicU32::new(0);

        for _ in 0..3 {
            sched.scope(|s| {
                s.run_with_hosts(|_, hosts| {
                    hosts.for_each(|host| {
                        counter.fetch_add(1, Ordering::SeqCst);
                        host
                    });
                });
            });
        }

        assert_eq!(counter.load(Ordering::SeqCst), 5 * 3);

        sched.join();
    }
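
    // A sketch exercising the work-stealing path: with more threads than hosts,
    // some threads start with empty queues and can only make progress by
    // stealing hosts from the other threads' queues.
    #[test]
    fn test_steal() {
        let hosts = [(); 3].map(|_| TestHost {});
        let mut sched: ThreadPerCoreSched<TestHost> =
            ThreadPerCoreSched::new(&[None, None, None, None], hosts, false);

        let counter = AtomicU32::new(0);

        sched.scope(|s| {
            s.run_with_hosts(|_, hosts| {
                hosts.for_each(|host| {
                    counter.fetch_add(1, Ordering::SeqCst);
                    host
                });
            });
        });

        // every host is processed exactly once, regardless of which thread took it
        assert_eq!(counter.load(Ordering::SeqCst), 3);

        sched.join();
    }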

    #[test]
    fn test_run_with_data() {
        let hosts = [(); 5].map(|_| TestHost {});
        let mut sched: ThreadPerCoreSched<TestHost> =
            ThreadPerCoreSched::new(&[None, None], hosts, false);

        let data = vec![0u32; sched.parallelism()];
        let data: Vec<_> = data.into_iter().map(std::sync::Mutex::new).collect();

        for _ in 0..3 {
            sched.scope(|s| {
                s.run_with_data(&data, |_, hosts, elem| {
                    let mut elem = elem.lock().unwrap();
                    hosts.for_each(|host| {
                        *elem += 1;
                        host
                    });
                });
            });
        }

        let sum: u32 = data.into_iter().map(|x| x.into_inner().unwrap()).sum();
        assert_eq!(sum, 5 * 3);

        sched.join();
    }
}