#![forbid(unsafe_code)]
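//! A thread-per-core host scheduler.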

use std::fmt::Debug;

use crossbeam::queue::ArrayQueue;

use crate::CORE_AFFINITY;
use crate::pools::unbounded::{TaskRunner, UnboundedThreadPool};

pub trait Host: Debug + Send {}
impl<T> Host for T where T: Debug + Send {}

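/// A host scheduler that distributes hosts across a fixed pool of worker threads, with
/// each thread optionally pinned to a CPU core. Each thread owns two host queues: hosts
/// are popped from one and pushed to the other as they are processed, and the queues
/// are swapped before the next scope runs.
///
/// A rough usage sketch (the host values are elided here; see the tests at the bottom
/// of this file for runnable examples):
///
/// ```ignore
/// let mut sched = ThreadPerCoreSched::new(&[Some(0), Some(1)], hosts, false);
/// sched.scope(|s| {
///     s.run_with_hosts(|_thread_idx, hosts| {
///         hosts.for_each(|host| {
///             // process the host, then hand it back to the scheduler
///             host
///         });
///     });
/// });
/// sched.join();
/// ```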
pub struct ThreadPerCoreSched<HostType: Host> {
    pool: UnboundedThreadPool,
    num_threads: usize,
    thread_hosts: Vec<ArrayQueue<HostType>>,
    thread_hosts_processed: Vec<ArrayQueue<HostType>>,
    hosts_need_swap: bool,
}

impl<HostType: Host> ThreadPerCoreSched<HostType> {
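    /// Creates a new scheduler with one worker thread per entry in `cpu_ids`. A
    /// `Some(id)` entry pins its thread to that CPU; `None` leaves the thread unpinned.
    /// Hosts are distributed across the threads round-robin, and `yield_spin` is
    /// forwarded to the underlying [`UnboundedThreadPool`].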
    pub fn new<T>(cpu_ids: &[Option<u32>], hosts: T, yield_spin: bool) -> Self
    where
        T: IntoIterator<Item = HostType, IntoIter: ExactSizeIterator>,
    {
        let hosts = hosts.into_iter();

        let num_threads = cpu_ids.len();
        let mut pool = UnboundedThreadPool::new(num_threads, "shadow-worker", yield_spin);

        pool.scope(|s| {
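            // pin each worker thread to its assigned core, if one was given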
            s.run(|i| {
                let cpu_id = cpu_ids[i];

                if let Some(cpu_id) = cpu_id {
                    let mut cpus = nix::sched::CpuSet::new();
                    cpus.set(cpu_id as usize).unwrap();
                    nix::sched::sched_setaffinity(nix::unistd::Pid::from_raw(0), &cpus).unwrap();

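                    // record the core affinity in this worker's thread-local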
                    CORE_AFFINITY.with(|x| x.set(Some(cpu_id)));
                }
            });
        });

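        // each thread gets two fixed-size queues, each with capacity for every host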
        let thread_hosts: Vec<_> = (0..num_threads)
            .map(|_| ArrayQueue::new(hosts.len()))
            .collect();
        let thread_hosts_2: Vec<_> = (0..num_threads)
            .map(|_| ArrayQueue::new(hosts.len()))
            .collect();

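        // assign hosts to threads in a round-robin manner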
        for (thread_queue, host) in thread_hosts.iter().cycle().zip(hosts) {
            thread_queue.push(host).unwrap();
        }

        Self {
            pool,
            num_threads,
            thread_hosts,
            thread_hosts_processed: thread_hosts_2,
            hosts_need_swap: false,
        }
    }

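    /// The number of worker threads, which bounds how many tasks can run in parallel.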
    pub fn parallelism(&self) -> usize {
        self.num_threads
    }

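    /// Runs a round of work: the closure is handed a [`SchedulerScope`] that can run
    /// tasks on every scheduler thread, and all tasks started in the scope complete
    /// before this method returns.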
    pub fn scope<'scope>(
        &'scope mut self,
        f: impl for<'a, 'b> FnOnce(SchedulerScope<'a, 'b, 'scope, HostType>) + 'scope,
    ) {
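        // we can't swap the queues after the `pool.scope()` call below due to lifetime
        // restrictions, so swap now if the previous scope left hosts in the processed
        // queues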
        if self.hosts_need_swap {
            debug_assert!(self.thread_hosts.iter().all(|queue| queue.is_empty()));

            std::mem::swap(&mut self.thread_hosts, &mut self.thread_hosts_processed);
            self.hosts_need_swap = false;
        }

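        // references that will move into the scope closure below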
        let thread_hosts = &self.thread_hosts;
        let thread_hosts_processed = &self.thread_hosts_processed;
        let hosts_need_swap = &mut self.hosts_need_swap;

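        // `self` can't be used past this point since `SchedulerScope` borrows these
        // fields for `'scope`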
        self.pool.scope(move |s| {
            let sched_scope = SchedulerScope {
                thread_hosts,
                thread_hosts_processed,
                hosts_need_swap,
                runner: s,
            };

            (f)(sched_scope);
        });
    }

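    /// Joins all worker threads, consuming the scheduler.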
    pub fn join(self) {
        self.pool.join();
    }
}

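/// A scope for running tasks on the scheduler's threads, passed to the closure given to
/// [`ThreadPerCoreSched::scope`].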
pub struct SchedulerScope<'sched, 'pool, 'scope, HostType: Host>
where
    'sched: 'scope,
{
    thread_hosts: &'sched Vec<ArrayQueue<HostType>>,
    thread_hosts_processed: &'sched Vec<ArrayQueue<HostType>>,
    hosts_need_swap: &'sched mut bool,
    runner: TaskRunner<'pool, 'scope>,
}

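// several named lifetimes are in play here, so keep them explicit rather than elided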
#[allow(clippy::needless_lifetimes)]
impl<'sched, 'pool, 'scope, HostType: Host> SchedulerScope<'sched, 'pool, 'scope, HostType> {
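    /// Runs the closure once on every scheduler thread, passing the thread's index.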
    pub fn run(self, f: impl Fn(usize) + Sync + Send + 'scope) {
        self.runner.run(f);
    }

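    /// Runs the closure once on every scheduler thread, passing the thread's index and
    /// a [`HostIter`] over the hosts that thread should process. Processed hosts land
    /// in the per-thread "processed" queues, which are swapped back in before the next
    /// scope runs.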
    pub fn run_with_hosts(
        self,
        f: impl Fn(usize, &mut HostIter<'_, HostType>) + Send + Sync + 'scope,
    ) {
        self.runner.run(move |i| {
            let mut host_iter = HostIter {
                thread_hosts_from: self.thread_hosts,
                thread_hosts_to: &self.thread_hosts_processed[i],
                this_thread_index: i,
            };

            f(i, &mut host_iter);
        });

        *self.hosts_need_swap = true;
    }

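    /// Like [`Self::run_with_hosts`], but also passes a reference to the element of
    /// `data` at the thread's index, so `data` must contain one element per scheduler
    /// thread.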
    pub fn run_with_data<T>(
        self,
        data: &'scope [T],
        f: impl Fn(usize, &mut HostIter<'_, HostType>, &T) + Send + Sync + 'scope,
    ) where
        T: Sync,
    {
        self.runner.run(move |i| {
            let this_elem = &data[i];

            let mut host_iter = HostIter {
                thread_hosts_from: self.thread_hosts,
                thread_hosts_to: &self.thread_hosts_processed[i],
                this_thread_index: i,
            };

            f(i, &mut host_iter, this_elem);
        });

        *self.hosts_need_swap = true;
    }
}

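/// Iterates over the hosts assigned to a thread. For this thread-per-core scheduler,
/// the iterator may also steal hosts from other threads' queues.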
pub struct HostIter<'a, HostType: Host> {
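    /// Queues to take hosts from.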
    thread_hosts_from: &'a [ArrayQueue<HostType>],
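    /// The queue to move processed hosts to.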
    thread_hosts_to: &'a ArrayQueue<HostType>,
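    /// The index of this thread; its own queue is drained first.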
    this_thread_index: usize,
}

impl<HostType: Host> HostIter<'_, HostType> {
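    /// Calls the closure on each host, moving the host it returns to the processed
    /// queue. The closure must hand each host back.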
    pub fn for_each<F>(&mut self, mut f: F)
    where
        F: FnMut(HostType) -> HostType,
    {
        for from_queue in self
            .thread_hosts_from
            .iter()
            .cycle()
            .skip(self.this_thread_index)
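            // visit each queue exactly once, starting from this thread's own queue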
            .take(self.thread_hosts_from.len())
        {
            while let Some(host) = from_queue.pop() {
                self.thread_hosts_to.push(f(host)).unwrap();
            }
        }
    }
}

#[cfg(any(test, doctest))]
mod tests {
    use std::sync::atomic::{AtomicU32, Ordering};

    use super::*;

    #[derive(Debug)]
    struct TestHost {}

    #[test]
    fn test_parallelism() {
        let hosts = [(); 5].map(|_| TestHost {});
        let sched: ThreadPerCoreSched<TestHost> =
            ThreadPerCoreSched::new(&[None, None], hosts, false);

        assert_eq!(sched.parallelism(), 2);

        sched.join();
    }

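    // dropping a scheduler without joining it should not panic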
    #[test]
    fn test_no_join() {
        let hosts = [(); 5].map(|_| TestHost {});
        let _sched: ThreadPerCoreSched<TestHost> =
            ThreadPerCoreSched::new(&[None, None], hosts, false);
    }

    #[test]
    #[should_panic]
    fn test_panic() {
        let hosts = [(); 5].map(|_| TestHost {});
        let mut sched: ThreadPerCoreSched<TestHost> =
            ThreadPerCoreSched::new(&[None, None], hosts, false);

        sched.scope(|s| {
            s.run(|x| {
                if x == 1 {
                    panic!();
                }
            });
        });
    }

    #[test]
    fn test_run() {
        let hosts = [(); 5].map(|_| TestHost {});
        let mut sched: ThreadPerCoreSched<TestHost> =
            ThreadPerCoreSched::new(&[None, None], hosts, false);

        let counter = AtomicU32::new(0);

        for _ in 0..3 {
            sched.scope(|s| {
                s.run(|_| {
                    counter.fetch_add(1, Ordering::SeqCst);
                });
            });
        }

        assert_eq!(counter.load(Ordering::SeqCst), 2 * 3);

        sched.join();
    }

    #[test]
    fn test_run_with_hosts() {
        let hosts = [(); 5].map(|_| TestHost {});
        let mut sched: ThreadPerCoreSched<TestHost> =
            ThreadPerCoreSched::new(&[None, None], hosts, false);

        let counter = AtomicU32::new(0);

        for _ in 0..3 {
            sched.scope(|s| {
                s.run_with_hosts(|_, hosts| {
                    hosts.for_each(|host| {
                        counter.fetch_add(1, Ordering::SeqCst);
                        host
                    });
                });
            });
        }

        assert_eq!(counter.load(Ordering::SeqCst), 5 * 3);

        sched.join();
    }

    #[test]
    fn test_run_with_data() {
        let hosts = [(); 5].map(|_| TestHost {});
        let mut sched: ThreadPerCoreSched<TestHost> =
            ThreadPerCoreSched::new(&[None, None], hosts, false);

        let data = vec![0u32; sched.parallelism()];
        let data: Vec<_> = data.into_iter().map(std::sync::Mutex::new).collect();

        for _ in 0..3 {
            sched.scope(|s| {
                s.run_with_data(&data, |_, hosts, elem| {
                    let mut elem = elem.lock().unwrap();
                    hosts.for_each(|host| {
                        *elem += 1;
                        host
                    });
                });
            });
        }

        let sum: u32 = data.into_iter().map(|x| x.into_inner().unwrap()).sum();
        assert_eq!(sum, 5 * 3);

        sched.join();
    }
}