scheduler/thread_per_host.rs
//! A thread-per-host scheduler: each host is assigned to its own worker thread for the
//! lifetime of the scheduler.

#![forbid(unsafe_code)]

use std::cell::RefCell;
use std::fmt::Debug;
use std::sync::Mutex;
use std::thread::LocalKey;

use crate::CORE_AFFINITY;
use crate::pools::bounded::{ParallelismBoundedThreadPool, TaskRunner};

/// A host type that can be scheduled; it only needs to be `Debug`, `Send`, and `'static`.
pub trait Host: Debug + Send + 'static {}
impl<T> Host for T where T: Debug + Send + 'static {}

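/// A scheduler that assigns each host to its own worker thread. A minimal usage sketch,
/// mirroring the tests at the bottom of this file (marked `ignore` so it is not compiled
/// as a doctest; `MyHost` and `HOST_STORAGE` are placeholder names):
///
/// ```ignore
/// #[derive(Debug)]
/// struct MyHost {}
///
/// std::thread_local! {
///     static HOST_STORAGE: RefCell<Option<MyHost>> = const { RefCell::new(None) };
/// }
///
/// // two worker threads (one per host), neither pinned to a specific CPU
/// let mut sched = ThreadPerHostSched::new(&[None, None], &HOST_STORAGE, vec![MyHost {}, MyHost {}]);
///
/// sched.scope(|s| {
///     s.run_with_hosts(|_thread_idx, hosts| {
///         hosts.for_each(|host| {
///             // work with the host, then hand it back
///             host
///         });
///     });
/// });
///
/// sched.join();
/// ```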
pub struct ThreadPerHostSched<HostType: Host> {
    /// The underlying thread pool.
    pool: ParallelismBoundedThreadPool,
    /// Thread-local storage where each worker thread keeps its assigned host.
    host_storage: &'static LocalKey<RefCell<Option<HostType>>>,
}

impl<HostType: Host> ThreadPerHostSched<HostType> {
    /// Creates a new scheduler with one worker thread per host. Threads are pinned to
    /// the given CPU IDs (`None` means unpinned), and each thread takes ownership of a
    /// single host by moving it into that thread's `host_storage`, which must start out
    /// empty (`None`).
    pub fn new<T>(
        cpu_ids: &[Option<u32>],
        host_storage: &'static LocalKey<RefCell<Option<HostType>>>,
        hosts: T,
    ) -> Self
    where
        T: IntoIterator<Item = HostType, IntoIter: ExactSizeIterator>,
    {
        let hosts = hosts.into_iter();

        let mut pool = ParallelismBoundedThreadPool::new(cpu_ids, hosts.len(), "shadow-worker");

        // each worker thread takes ownership of exactly one of these hosts
        let hosts: Vec<Mutex<Option<HostType>>> = hosts.map(|x| Mutex::new(Some(x))).collect();

        pool.scope(|s| {
            s.run(|t| {
                host_storage.with(|x| {
                    assert!(x.borrow().is_none());
                    let host = hosts[t.thread_idx].lock().unwrap().take().unwrap();
                    *x.borrow_mut() = Some(host);
                });
            });
        });

        Self { pool, host_storage }
    }

    /// The number of logical processors available to the scheduler.
    pub fn parallelism(&self) -> usize {
        self.pool.num_processors()
    }

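    /// Runs a scheduling scope: the closure receives a [`SchedulerScope`] that can submit
    /// work to the scheduler's threads. A minimal sketch (not compiled as a doctest),
    /// assuming `sched` is a `ThreadPerHostSched` built as in the tests below:
    ///
    /// ```ignore
    /// sched.scope(|s| {
    ///     s.run(|thread_idx| {
    ///         println!("hello from thread {thread_idx}");
    ///     });
    /// });
    /// ```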
    pub fn scope<'scope>(
        &'scope mut self,
        f: impl for<'a> FnOnce(SchedulerScope<'a, 'scope, HostType>) + 'scope,
    ) {
        let host_storage = self.host_storage;
        self.pool.scope(move |s| {
            let sched_scope = SchedulerScope {
                runner: s,
                host_storage,
            };

            (f)(sched_scope);
        });
    }

    /// Joins all threads and consumes the scheduler. Each thread's host is first moved
    /// out of its thread-local storage into a local list, so the hosts are dropped when
    /// `join` returns rather than left behind in thread-local storage.
    pub fn join(mut self) {
        let hosts: Vec<Mutex<Option<HostType>>> = (0..self.pool.num_threads())
            .map(|_| Mutex::new(None))
            .collect();

        self.pool.scope(|s| {
            s.run(|t| {
                self.host_storage.with(|x| {
                    let host = x.borrow_mut().take().unwrap();
                    *hosts[t.thread_idx].lock().unwrap() = Some(host);
                });
            });
        });

        self.pool.join();
    }
}

/// A scope for running tasks on the scheduler's threads.
pub struct SchedulerScope<'pool, 'scope, HostType: Host> {
    /// The runner for the underlying thread pool.
    runner: TaskRunner<'pool, 'scope>,
    /// Thread-local storage where each worker thread keeps its assigned host.
    host_storage: &'static LocalKey<RefCell<Option<HostType>>>,
}

#[allow(clippy::needless_lifetimes)]
impl<'pool, 'scope, HostType: Host> SchedulerScope<'pool, 'scope, HostType> {
    /// Runs the closure once on each thread; the closure receives the thread's index.
    pub fn run(self, f: impl Fn(usize) + Sync + Send + 'scope) {
        self.runner.run(move |task_context| {
            // update the thread-local core affinity to match this task's assigned CPU
            if let Some(cpu_id) = task_context.cpu_id {
                CORE_AFFINITY.with(|x| x.set(Some(cpu_id)));
            }

            (f)(task_context.thread_idx)
        });
    }

    /// Runs the closure once on each thread; the closure receives the thread's index and
    /// an iterator over the hosts assigned to that thread (a single host here).
    pub fn run_with_hosts(self, f: impl Fn(usize, &mut HostIter<HostType>) + Send + Sync + 'scope) {
        self.runner.run(move |task_context| {
            if let Some(cpu_id) = task_context.cpu_id {
                CORE_AFFINITY.with(|x| x.set(Some(cpu_id)));
            }

            self.host_storage.with(|host| {
                let mut host = host.borrow_mut();

                // temporarily move the host out of thread-local storage and into the iterator
                let mut host_iter = HostIter { host: host.take() };

                f(task_context.thread_idx, &mut host_iter);

                // return the host to thread-local storage for the next task
                host.replace(host_iter.host.take().unwrap());
            });
        });
    }

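    /// Runs the closure once on each thread; the closure receives the thread's index, an
    /// iterator over the thread's host, and the element of `data` at the thread's
    /// processor index (so `data` needs one element per logical processor). A minimal
    /// sketch (not compiled as a doctest), assuming `sched` was built as in the tests
    /// below:
    ///
    /// ```ignore
    /// let data: Vec<_> = (0..sched.parallelism()).map(|_| std::sync::Mutex::new(0u32)).collect();
    ///
    /// sched.scope(|s| {
    ///     s.run_with_data(&data, |_thread_idx, hosts, elem| {
    ///         let mut elem = elem.lock().unwrap();
    ///         hosts.for_each(|host| {
    ///             // each thread bumps its processor's shared counter once per host
    ///             *elem += 1;
    ///             host
    ///         });
    ///     });
    /// });
    /// ```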
    pub fn run_with_data<T>(
        self,
        data: &'scope [T],
        f: impl Fn(usize, &mut HostIter<HostType>, &T) + Send + Sync + 'scope,
    ) where
        T: Sync,
    {
        self.runner.run(move |task_context| {
            if let Some(cpu_id) = task_context.cpu_id {
                CORE_AFFINITY.with(|x| x.set(Some(cpu_id)));
            }

            // each logical processor gets its own element of `data`
            let this_elem = &data[task_context.processor_idx];

            self.host_storage.with(|host| {
                let mut host = host.borrow_mut();

                let mut host_iter = HostIter { host: host.take() };

                f(task_context.thread_idx, &mut host_iter, this_elem);

                host.replace(host_iter.host.unwrap());
            });
        });
    }
}

/// Supports iterating over the hosts assigned to this thread. For this thread-per-host
/// scheduler there is always exactly one host per thread.
pub struct HostIter<HostType: Host> {
    /// The host owned by this thread.
    host: Option<HostType>,
}

impl<HostType: Host> HostIter<HostType> {
    /// Applies the closure to each host, taking ownership of the host and storing back
    /// whatever host the closure returns.
    pub fn for_each<F>(&mut self, mut f: F)
    where
        F: FnMut(HostType) -> HostType,
    {
        let host = self.host.take().unwrap();
        self.host.replace(f(host));
    }
}

#[cfg(any(test, doctest))]
mod tests {
    use std::cell::RefCell;
    use std::sync::atomic::{AtomicU32, Ordering};

    use super::*;

    #[derive(Debug)]
    struct TestHost {}

    std::thread_local! {
        static SCHED_HOST_STORAGE: RefCell<Option<TestHost>> = const { RefCell::new(None) };
    }

    #[test]
    fn test_parallelism() {
        let hosts = [(); 5].map(|_| TestHost {});
        let sched: ThreadPerHostSched<TestHost> =
            ThreadPerHostSched::new(&[None, None], &SCHED_HOST_STORAGE, hosts);

        assert_eq!(sched.parallelism(), 2);

        sched.join();
    }

    #[test]
    fn test_no_join() {
        let hosts = [(); 5].map(|_| TestHost {});
        let _sched: ThreadPerHostSched<TestHost> =
            ThreadPerHostSched::new(&[None, None], &SCHED_HOST_STORAGE, hosts);
    }

    #[test]
    #[should_panic]
    fn test_panic() {
        let hosts = [(); 5].map(|_| TestHost {});
        let mut sched: ThreadPerHostSched<TestHost> =
            ThreadPerHostSched::new(&[None, None], &SCHED_HOST_STORAGE, hosts);

        sched.scope(|s| {
            s.run(|x| {
                if x == 1 {
                    panic!();
                }
            });
        });
    }

    #[test]
    fn test_run() {
        let hosts = [(); 5].map(|_| TestHost {});
        let mut sched: ThreadPerHostSched<TestHost> =
            ThreadPerHostSched::new(&[None, None], &SCHED_HOST_STORAGE, hosts);

        let counter = AtomicU32::new(0);

        for _ in 0..3 {
            sched.scope(|s| {
                s.run(|_| {
                    counter.fetch_add(1, Ordering::SeqCst);
                });
            });
        }

        assert_eq!(counter.load(Ordering::SeqCst), 5 * 3);

        sched.join();
    }

    #[test]
    fn test_run_with_hosts() {
        let hosts = [(); 5].map(|_| TestHost {});
        let mut sched: ThreadPerHostSched<TestHost> =
            ThreadPerHostSched::new(&[None, None], &SCHED_HOST_STORAGE, hosts);

        let counter = AtomicU32::new(0);

        for _ in 0..3 {
            sched.scope(|s| {
                s.run_with_hosts(|_, hosts| {
                    hosts.for_each(|host| {
                        counter.fetch_add(1, Ordering::SeqCst);
                        host
                    });
                });
            });
        }

        assert_eq!(counter.load(Ordering::SeqCst), 5 * 3);

        sched.join();
    }

    #[test]
    fn test_run_with_data() {
        let hosts = [(); 5].map(|_| TestHost {});
        let mut sched: ThreadPerHostSched<TestHost> =
            ThreadPerHostSched::new(&[None, None], &SCHED_HOST_STORAGE, hosts);

        let data = vec![0u32; sched.parallelism()];
        let data: Vec<_> = data.into_iter().map(std::sync::Mutex::new).collect();

        for _ in 0..3 {
            sched.scope(|s| {
                s.run_with_data(&data, |_, hosts, elem| {
                    let mut elem = elem.lock().unwrap();
                    hosts.for_each(|host| {
                        *elem += 1;
                        host
                    });
                });
            });
        }

        let sum: u32 = data.into_iter().map(|x| x.into_inner().unwrap()).sum();
        assert_eq!(sum, 5 * 3);

        sched.join();
    }
}