scheduler/
thread_per_host.rs

1//! A thread-per-host host scheduler.
2
3// unsafe code should be isolated to the thread pool
4#![forbid(unsafe_code)]
5
6use std::cell::RefCell;
7use std::fmt::Debug;
8use std::sync::Mutex;
9use std::thread::LocalKey;
10
11use crate::CORE_AFFINITY;
12use crate::pools::bounded::{ParallelismBoundedThreadPool, TaskRunner};
13
/// Marker trait for types that the scheduler can treat as a host: anything that is `Debug`,
/// `Send`, and owns no borrowed data (`'static`).
pub trait Host: Debug + Send + 'static {}
// Blanket impl: every type meeting the bounds is automatically a `Host`.
impl<T> Host for T where T: Debug + Send + 'static {}
16
/// A host scheduler.
///
/// Each host is assigned to exactly one worker thread for its entire lifetime, and is kept in
/// that thread's thread-local storage between scheduling rounds.
pub struct ThreadPerHostSched<HostType: Host> {
    /// The thread pool.
    pool: ParallelismBoundedThreadPool,
    /// Thread-local storage where a thread can store its host.
    host_storage: &'static LocalKey<RefCell<Option<HostType>>>,
}
24
25impl<HostType: Host> ThreadPerHostSched<HostType> {
26    /// A new host scheduler with logical processors that are pinned to the provided OS processors.
27    /// Each logical processor is assigned many threads, and each thread is given a single host. The
28    /// number of threads created will be the length of `hosts`.
29    ///
30    /// An empty `host_storage` for thread-local storage is required for each thread to have
31    /// efficient access to its host. A panic may occur if `host_storage` is not `None`, or if it is
32    /// borrowed while the scheduler is in use.
33    pub fn new<T>(
34        cpu_ids: &[Option<u32>],
35        host_storage: &'static LocalKey<RefCell<Option<HostType>>>,
36        hosts: T,
37    ) -> Self
38    where
39        T: IntoIterator<Item = HostType, IntoIter: ExactSizeIterator>,
40    {
41        let hosts = hosts.into_iter();
42
43        let mut pool = ParallelismBoundedThreadPool::new(cpu_ids, hosts.len(), "shadow-worker");
44
45        // for determinism, threads will take hosts from a vec rather than a queue
46        let hosts: Vec<Mutex<Option<HostType>>> = hosts.map(|x| Mutex::new(Some(x))).collect();
47
48        // have each thread take a host and store it as a thread-local
49        pool.scope(|s| {
50            s.run(|t| {
51                host_storage.with(|x| {
52                    assert!(x.borrow().is_none());
53                    let host = hosts[t.thread_idx].lock().unwrap().take().unwrap();
54                    *x.borrow_mut() = Some(host);
55                });
56            });
57        });
58
59        Self { pool, host_storage }
60    }
61
62    /// See [`crate::Scheduler::parallelism`].
63    pub fn parallelism(&self) -> usize {
64        self.pool.num_processors()
65    }
66
67    /// See [`crate::Scheduler::scope`].
68    pub fn scope<'scope>(
69        &'scope mut self,
70        f: impl for<'a> FnOnce(SchedulerScope<'a, 'scope, HostType>) + 'scope,
71    ) {
72        let host_storage = self.host_storage;
73        self.pool.scope(move |s| {
74            let sched_scope = SchedulerScope {
75                runner: s,
76                host_storage,
77            };
78
79            (f)(sched_scope);
80        });
81    }
82
83    /// See [`crate::Scheduler::join`].
84    pub fn join(mut self) {
85        let hosts: Vec<Mutex<Option<HostType>>> = (0..self.pool.num_threads())
86            .map(|_| Mutex::new(None))
87            .collect();
88
89        // collect all of the hosts from the threads
90        self.pool.scope(|s| {
91            s.run(|t| {
92                self.host_storage.with(|x| {
93                    let host = x.borrow_mut().take().unwrap();
94                    *hosts[t.thread_idx].lock().unwrap() = Some(host);
95                });
96            });
97        });
98
99        self.pool.join();
100    }
101}
102
/// A wrapper around the work pool's scoped runner.
///
/// Handed to the closure passed to [`ThreadPerHostSched::scope`]; it pairs the pool's task runner
/// with the thread-local storage holding each thread's host.
pub struct SchedulerScope<'pool, 'scope, HostType: Host> {
    /// The work pool's scoped runner.
    runner: TaskRunner<'pool, 'scope>,
    /// Thread-local storage where a thread can retrieve its host.
    host_storage: &'static LocalKey<RefCell<Option<HostType>>>,
}
110
111// there are multiple named lifetimes, so let's just be explicit about them rather than hide them
112#[allow(clippy::needless_lifetimes)]
113impl<'pool, 'scope, HostType: Host> SchedulerScope<'pool, 'scope, HostType> {
114    /// See [`crate::SchedulerScope::run`].
115    pub fn run(self, f: impl Fn(usize) + Sync + Send + 'scope) {
116        self.runner.run(move |task_context| {
117            // update the thread-local core affinity
118            if let Some(cpu_id) = task_context.cpu_id {
119                CORE_AFFINITY.with(|x| x.set(Some(cpu_id)));
120            }
121
122            (f)(task_context.thread_idx)
123        });
124    }
125
126    /// See [`crate::SchedulerScope::run_with_hosts`].
127    pub fn run_with_hosts(self, f: impl Fn(usize, &mut HostIter<HostType>) + Send + Sync + 'scope) {
128        self.runner.run(move |task_context| {
129            // update the thread-local core affinity
130            if let Some(cpu_id) = task_context.cpu_id {
131                CORE_AFFINITY.with(|x| x.set(Some(cpu_id)));
132            }
133
134            self.host_storage.with(|host| {
135                let mut host = host.borrow_mut();
136
137                let mut host_iter = HostIter { host: host.take() };
138
139                f(task_context.thread_idx, &mut host_iter);
140
141                host.replace(host_iter.host.take().unwrap());
142            });
143        });
144    }
145
146    /// See [`crate::SchedulerScope::run_with_data`].
147    pub fn run_with_data<T>(
148        self,
149        data: &'scope [T],
150        f: impl Fn(usize, &mut HostIter<HostType>, &T) + Send + Sync + 'scope,
151    ) where
152        T: Sync,
153    {
154        self.runner.run(move |task_context| {
155            // update the thread-local core affinity
156            if let Some(cpu_id) = task_context.cpu_id {
157                CORE_AFFINITY.with(|x| x.set(Some(cpu_id)));
158            }
159
160            let this_elem = &data[task_context.processor_idx];
161
162            self.host_storage.with(|host| {
163                let mut host = host.borrow_mut();
164
165                let mut host_iter = HostIter { host: host.take() };
166
167                f(task_context.thread_idx, &mut host_iter, this_elem);
168
169                host.replace(host_iter.host.unwrap());
170            });
171        });
172    }
173}
174
/// Supports iterating over all hosts assigned to this thread. For this thread-per-host scheduler,
/// there will only ever be one host per thread.
pub struct HostIter<HostType: Host> {
    // `None` only transiently, while the host is lent out to a `for_each` closure.
    host: Option<HostType>,
}
180
181impl<HostType: Host> HostIter<HostType> {
182    /// See [`crate::HostIter::for_each`].
183    pub fn for_each<F>(&mut self, mut f: F)
184    where
185        F: FnMut(HostType) -> HostType,
186    {
187        let host = self.host.take().unwrap();
188        self.host.replace(f(host));
189    }
190}
191
#[cfg(any(test, doctest))]
mod tests {
    use std::cell::RefCell;
    use std::sync::atomic::{AtomicU32, Ordering};

    use super::*;

    /// Minimal host type for exercising the scheduler.
    #[derive(Debug)]
    struct TestHost {}

    std::thread_local! {
        static SCHED_HOST_STORAGE: RefCell<Option<TestHost>> = const { RefCell::new(None) };
    }

    /// Builds a scheduler over two logical processors and `n` test hosts.
    fn make_sched(n: usize) -> ThreadPerHostSched<TestHost> {
        let hosts = (0..n).map(|_| TestHost {});
        ThreadPerHostSched::new(&[None, None], &SCHED_HOST_STORAGE, hosts)
    }

    #[test]
    fn test_parallelism() {
        let sched = make_sched(5);

        // two cpu entries were provided, so two logical processors
        assert_eq!(sched.parallelism(), 2);

        sched.join();
    }

    #[test]
    fn test_no_join() {
        // dropping without join() must not hang or panic
        let _sched = make_sched(5);
    }

    #[test]
    #[should_panic]
    fn test_panic() {
        let mut sched = make_sched(5);

        sched.scope(|s| {
            s.run(|thread_idx| {
                if thread_idx == 1 {
                    panic!();
                }
            });
        });
    }

    #[test]
    fn test_run() {
        let mut sched = make_sched(5);

        let tally = AtomicU32::new(0);

        for _ in 0..3 {
            sched.scope(|s| {
                s.run(|_| {
                    tally.fetch_add(1, Ordering::SeqCst);
                });
            });
        }

        // 5 threads, 3 rounds
        assert_eq!(tally.load(Ordering::SeqCst), 5 * 3);

        sched.join();
    }

    #[test]
    fn test_run_with_hosts() {
        let mut sched = make_sched(5);

        let tally = AtomicU32::new(0);

        for _ in 0..3 {
            sched.scope(|s| {
                s.run_with_hosts(|_, hosts| {
                    hosts.for_each(|host| {
                        tally.fetch_add(1, Ordering::SeqCst);
                        host
                    });
                });
            });
        }

        // one host per thread, 5 threads, 3 rounds
        assert_eq!(tally.load(Ordering::SeqCst), 5 * 3);

        sched.join();
    }

    #[test]
    fn test_run_with_data() {
        let mut sched = make_sched(5);

        // one counter per logical processor
        let data: Vec<_> = (0..sched.parallelism())
            .map(|_| std::sync::Mutex::new(0u32))
            .collect();

        for _ in 0..3 {
            sched.scope(|s| {
                s.run_with_data(&data, |_, hosts, elem| {
                    let mut elem = elem.lock().unwrap();
                    hosts.for_each(|host| {
                        *elem += 1;
                        host
                    });
                });
            });
        }

        // every (host, round) pair incremented exactly one counter
        let sum: u32 = data.into_iter().map(|x| x.into_inner().unwrap()).sum();
        assert_eq!(sum, 5 * 3);

        sched.join();
    }
}
311}