#![forbid(unsafe_code)]
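//! A scheduler that runs one worker thread per host, keeping each host in its thread's
//! local storage for the lifetime of the scheduler.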
use std::cell::RefCell;
use std::fmt::Debug;
use std::sync::Mutex;
use std::thread::LocalKey;

use crate::pools::bounded::{ParallelismBoundedThreadPool, TaskRunner};
use crate::CORE_AFFINITY;
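
/// A host that can be run by the scheduler. The `Debug + Send + 'static` bounds allow a host to be
/// moved into a worker thread and stored there; the blanket impl below makes any such type a `Host`.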
pub trait Host: Debug + Send + 'static {}
impl<T> Host for T where T: Debug + Send + 'static {}
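
/// A scheduler that assigns each host to its own worker thread. Every thread keeps its host in the
/// caller-provided thread-local `host_storage`, so tasks running on that thread can access the host
/// without locking or moving it between threads.
///
/// A rough usage sketch (not taken from real callers; `MyHost` and `HOST_STORAGE` are illustrative
/// names only):
///
/// ```ignore
/// #[derive(Debug)]
/// struct MyHost;
///
/// std::thread_local! {
///     static HOST_STORAGE: std::cell::RefCell<Option<MyHost>> =
///         const { std::cell::RefCell::new(None) };
/// }
///
/// // Two hosts, two worker threads; `None` means the threads are not pinned to specific CPUs.
/// let mut sched = ThreadPerHostSched::new(&[None, None], &HOST_STORAGE, vec![MyHost, MyHost]);
///
/// sched.scope(|s| {
///     s.run_with_hosts(|_thread_idx, hosts| {
///         hosts.for_each(|host| {
///             // work with the host, then hand it back
///             host
///         });
///     });
/// });
///
/// sched.join();
/// ```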
pub struct ThreadPerHostSched<HostType: Host> {
/// The underlying thread pool (one thread per host).
pool: ParallelismBoundedThreadPool,
/// Thread-local storage where each worker thread keeps its assigned host.
host_storage: &'static LocalKey<RefCell<Option<HostType>>>,
}
impl<HostType: Host> ThreadPerHostSched<HostType> {
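/// Creates a new scheduler with one worker thread per host. One logical processor is created per
/// entry of `cpu_ids` (a `Some` entry pins that processor's threads to the given OS CPU), and each
/// worker thread takes exactly one host from `hosts` and moves it into its `host_storage`
/// thread-local.
///
/// `host_storage` must hold `None` on every worker thread when the scheduler is created; this is
/// asserted before the host is stored.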
pub fn new<T>(
cpu_ids: &[Option<u32>],
host_storage: &'static LocalKey<RefCell<Option<HostType>>>,
hosts: T,
) -> Self
where
T: IntoIterator<Item = HostType, IntoIter: ExactSizeIterator>,
{
let hosts = hosts.into_iter();
let mut pool = ParallelismBoundedThreadPool::new(cpu_ids, hosts.len(), "shadow-worker");
let hosts: Vec<Mutex<Option<HostType>>> = hosts.map(|x| Mutex::new(Some(x))).collect();
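// Hand each host to a worker thread by moving it into that thread's local storage.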
pool.scope(|s| {
s.run(|t| {
host_storage.with(|x| {
assert!(x.borrow().is_none());
let host = hosts[t.thread_idx].lock().unwrap().take().unwrap();
*x.borrow_mut() = Some(host);
});
});
});
Self { pool, host_storage }
}
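
/// The maximum number of threads that will ever run in parallel; this is the number of logical
/// processors, not the number of hosts.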
pub fn parallelism(&self) -> usize {
self.pool.num_processors()
}
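
/// Runs tasks within a scope. The closure is given a [`SchedulerScope`] from which tasks can be
/// started; the calling thread blocks until all tasks started in the scope have completed.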
pub fn scope<'scope>(
&'scope mut self,
f: impl for<'a> FnOnce(SchedulerScope<'a, 'scope, HostType>) + 'scope,
) {
let host_storage = self.host_storage;
self.pool.scope(move |s| {
let sched_scope = SchedulerScope {
runner: s,
host_storage,
};
(f)(sched_scope);
});
}
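
/// Joins all worker threads. Each thread's host is first moved back out of its thread-local
/// storage and dropped on the calling thread.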
pub fn join(mut self) {
let hosts: Vec<Mutex<Option<HostType>>> = (0..self.pool.num_threads())
.map(|_| Mutex::new(None))
.collect();
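// Move each host back out of its thread's local storage so it is dropped on this thread.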
self.pool.scope(|s| {
s.run(|t| {
self.host_storage.with(|x| {
let host = x.borrow_mut().take().unwrap();
*hosts[t.thread_idx].lock().unwrap() = Some(host);
});
});
});
self.pool.join();
}
}
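
/// A scope within which tasks can be started on the scheduler's worker threads.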
pub struct SchedulerScope<'pool, 'scope, HostType: Host> {
/// Runs tasks on the pool's threads.
runner: TaskRunner<'pool, 'scope>,
/// Thread-local storage where each worker thread keeps its assigned host.
host_storage: &'static LocalKey<RefCell<Option<HostType>>>,
}
impl<'pool, 'scope, HostType: Host> SchedulerScope<'pool, 'scope, HostType> {
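/// Runs the closure once on every worker thread. The closure is given the index of the thread it
/// runs on.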
pub fn run(self, f: impl Fn(usize) + Sync + Send + 'scope) {
self.runner.run(move |task_context| {
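// Record the CPU this task is pinned to in the CORE_AFFINITY thread-local.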
if let Some(cpu_id) = task_context.cpu_id {
CORE_AFFINITY.with(|x| x.set(Some(cpu_id)));
}
(f)(task_context.thread_idx)
});
}
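
/// Runs the closure once on every worker thread, giving it the thread index and an iterator over
/// the hosts assigned to that thread (exactly one host for this scheduler). After the closure
/// returns, the host is moved back into the thread's local storage.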
pub fn run_with_hosts(self, f: impl Fn(usize, &mut HostIter<HostType>) + Send + Sync + 'scope) {
self.runner.run(move |task_context| {
if let Some(cpu_id) = task_context.cpu_id {
CORE_AFFINITY.with(|x| x.set(Some(cpu_id)));
}
self.host_storage.with(|host| {
let mut host = host.borrow_mut();
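// Temporarily take the host out of thread-local storage and lend it to the closure.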
let mut host_iter = HostIter { host: host.take() };
f(task_context.thread_idx, &mut host_iter);
host.replace(host_iter.host.take().unwrap());
});
});
}
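
/// Like [`SchedulerScope::run_with_hosts`], but the closure is also given one element of `data`.
/// Elements are indexed by logical processor, so `data` must have at least
/// [`ThreadPerHostSched::parallelism`] elements, and an element is never given to two threads at
/// the same time.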
pub fn run_with_data<T>(
self,
data: &'scope [T],
f: impl Fn(usize, &mut HostIter<HostType>, &T) + Send + Sync + 'scope,
) where
T: Sync,
{
self.runner.run(move |task_context| {
if let Some(cpu_id) = task_context.cpu_id {
CORE_AFFINITY.with(|x| x.set(Some(cpu_id)));
}
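// Each logical processor gets its own element of `data`.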
let this_elem = &data[task_context.processor_idx];
self.host_storage.with(|host| {
let mut host = host.borrow_mut();
let mut host_iter = HostIter { host: host.take() };
f(task_context.thread_idx, &mut host_iter, this_elem);
host.replace(host_iter.host.take().unwrap());
});
});
}
}
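
/// An iterator-like view of the hosts assigned to a worker thread. For this thread-per-host
/// scheduler there is always exactly one host.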
pub struct HostIter<HostType: Host> {
host: Option<HostType>,
}
impl<HostType: Host> HostIter<HostType> {
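/// Runs the closure on each host (here, the single host). The closure takes ownership of the host
/// and must return it so that it can be stored back on the thread.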
pub fn for_each<F>(&mut self, mut f: F)
where
F: FnMut(HostType) -> HostType,
{
let host = self.host.take().unwrap();
self.host.replace(f(host));
}
}

#[cfg(any(test, doctest))]
mod tests {
use std::cell::RefCell;
use std::sync::atomic::{AtomicU32, Ordering};
use super::*;

#[derive(Debug)]
struct TestHost {}

std::thread_local! {
static SCHED_HOST_STORAGE: RefCell<Option<TestHost>> = const { RefCell::new(None) };
}

#[test]
fn test_parallelism() {
let hosts = [(); 5].map(|_| TestHost {});
let sched: ThreadPerHostSched<TestHost> =
ThreadPerHostSched::new(&[None, None], &SCHED_HOST_STORAGE, hosts);
assert_eq!(sched.parallelism(), 2);
sched.join();
}

#[test]
fn test_no_join() {
let hosts = [(); 5].map(|_| TestHost {});
let _sched: ThreadPerHostSched<TestHost> =
ThreadPerHostSched::new(&[None, None], &SCHED_HOST_STORAGE, hosts);
}

#[test]
#[should_panic]
fn test_panic() {
let hosts = [(); 5].map(|_| TestHost {});
let mut sched: ThreadPerHostSched<TestHost> =
ThreadPerHostSched::new(&[None, None], &SCHED_HOST_STORAGE, hosts);
sched.scope(|s| {
s.run(|x| {
if x == 1 {
panic!();
}
});
});
}

#[test]
fn test_run() {
let hosts = [(); 5].map(|_| TestHost {});
let mut sched: ThreadPerHostSched<TestHost> =
ThreadPerHostSched::new(&[None, None], &SCHED_HOST_STORAGE, hosts);
let counter = AtomicU32::new(0);
for _ in 0..3 {
sched.scope(|s| {
s.run(|_| {
counter.fetch_add(1, Ordering::SeqCst);
});
});
}
assert_eq!(counter.load(Ordering::SeqCst), 5 * 3);
sched.join();
}

#[test]
fn test_run_with_hosts() {
let hosts = [(); 5].map(|_| TestHost {});
let mut sched: ThreadPerHostSched<TestHost> =
ThreadPerHostSched::new(&[None, None], &SCHED_HOST_STORAGE, hosts);
let counter = AtomicU32::new(0);
for _ in 0..3 {
sched.scope(|s| {
s.run_with_hosts(|_, hosts| {
hosts.for_each(|host| {
counter.fetch_add(1, Ordering::SeqCst);
host
});
});
});
}
assert_eq!(counter.load(Ordering::SeqCst), 5 * 3);
sched.join();
}

#[test]
fn test_run_with_data() {
let hosts = [(); 5].map(|_| TestHost {});
let mut sched: ThreadPerHostSched<TestHost> =
ThreadPerHostSched::new(&[None, None], &SCHED_HOST_STORAGE, hosts);
let data = vec![0u32; sched.parallelism()];
let data: Vec<_> = data.into_iter().map(std::sync::Mutex::new).collect();
for _ in 0..3 {
sched.scope(|s| {
s.run_with_data(&data, |_, hosts, elem| {
let mut elem = elem.lock().unwrap();
hosts.for_each(|host| {
*elem += 1;
host
});
});
});
}
let sum: u32 = data.into_iter().map(|x| x.into_inner().unwrap()).sum();
assert_eq!(sum, 5 * 3);
sched.join();
}
}