// scheduler/sync/thread_parking.rs

1use std::sync::Arc;
2#[cfg(debug_assertions)]
3use std::sync::Mutex;
4use std::sync::atomic::{AtomicBool, Ordering};
5
/// An unparker that has not yet been bound to the thread it will wake.
///
/// Once the target thread's handle is available, call [`ThreadUnparkerUnassigned::assign`] to
/// turn this into a [`ThreadUnparker`].
#[derive(Debug, Clone)]
pub struct ThreadUnparkerUnassigned {
    /// Wakeup flag shared with every parker/unparker derived from this value.
    ready_flag: Arc<AtomicBool>,
    /// ID of the only thread permitted to park; `None` until that thread is known. Only present
    /// in debug builds.
    #[cfg(debug_assertions)]
    shared_thread_id: Arc<Mutex<Option<std::thread::ThreadId>>>,
}
14
/// Unparks the specific thread it was assigned to.
#[derive(Debug, Clone)]
pub struct ThreadUnparker {
    /// Handle of the thread this unparker wakes.
    thread: std::thread::Thread,
    /// Wakeup flag shared with the corresponding [`ThreadParker`]s.
    ready_flag: Arc<AtomicBool>,
    /// ID of the only thread permitted to park. Only present in debug builds.
    #[cfg(debug_assertions)]
    shared_thread_id: Arc<Mutex<Option<std::thread::ThreadId>>>,
}
24
/// Parks the current thread until the matching unparker fires.
///
/// A `ThreadParker` is obtained from a [`ThreadUnparker`] or a [`ThreadUnparkerUnassigned`]. It
/// may only be used on the thread that unparker is (or will be) assigned to: if the unparker is
/// assigned to thread A, then [`ThreadParker::park()`] must only ever be called from thread A.
#[derive(Debug, Clone)]
pub struct ThreadParker {
    /// Wakeup flag shared with the unparker this parker was derived from.
    ready_flag: Arc<AtomicBool>,
    /// ID of the only thread permitted to park. Only present in debug builds.
    #[cfg(debug_assertions)]
    shared_thread_id: Arc<Mutex<Option<std::thread::ThreadId>>>,
}
36
37impl ThreadUnparkerUnassigned {
38    pub fn new() -> Self {
39        Self {
40            ready_flag: Arc::new(AtomicBool::new(false)),
41            // there is no assigned thread yet
42            #[cfg(debug_assertions)]
43            shared_thread_id: Arc::new(Mutex::new(None)),
44        }
45    }
46
47    /// Assign this to a thread that will be unparked.
48    #[must_use]
49    pub fn assign(self, thread: std::thread::Thread) -> ThreadUnparker {
50        ThreadUnparker::new(
51            self.ready_flag,
52            thread,
53            #[cfg(debug_assertions)]
54            self.shared_thread_id,
55        )
56    }
57
58    // we don't currently use this function, but I don't see a reason to delete it
59    #[allow(dead_code)]
60    /// Get a new [`ThreadParker`]. The `ThreadParker` must only be used from the thread which we
61    /// will later assign ourselves to using `assign()`. This is useful if you want to pass a
62    /// `ThreadParker` to a new thread before you have a handle to that thread.
63    pub fn parker(&self) -> ThreadParker {
64        ThreadParker::new(
65            Arc::clone(&self.ready_flag),
66            #[cfg(debug_assertions)]
67            Arc::clone(&self.shared_thread_id),
68        )
69    }
70}
71
72impl Default for ThreadUnparkerUnassigned {
73    fn default() -> Self {
74        Self::new()
75    }
76}
77
78impl ThreadUnparker {
79    fn new(
80        ready_flag: Arc<AtomicBool>,
81        thread: std::thread::Thread,
82        #[cfg(debug_assertions)] shared_thread_id: Arc<Mutex<Option<std::thread::ThreadId>>>,
83    ) -> Self {
84        // set the value of `shared_thread_id`, or if it was already set, verify that it's the
85        // correct value
86        #[cfg(debug_assertions)]
87        {
88            let mut shared_thread_id = shared_thread_id.lock().unwrap();
89
90            // it's valid to park before the unparker has been assigned to a thread
91            // (`shared_thread_id` would be `Some` in this case), so if it was already set we should
92            // check that it is the correct thread
93            let shared_thread_id = shared_thread_id.get_or_insert_with(|| thread.id());
94
95            assert_eq!(
96                *shared_thread_id,
97                thread.id(),
98                "An earlier `ThreadParker::park()` was called from the wrong thread"
99            );
100        }
101
102        Self {
103            ready_flag,
104            thread,
105            #[cfg(debug_assertions)]
106            shared_thread_id,
107        }
108    }
109
110    /// Unpark the assigned thread.
111    pub fn unpark(&self) {
112        // NOTE: Rust now does guarantee some synchronization between the thread that parks and the
113        // thread that unparks, so the change to `ready_flag` should be seen by the parked thread:
114        //
115        // https://doc.rust-lang.org/std/thread/fn.park.html#memory-ordering
116        //
117        // > Calls to park synchronize-with calls to unpark, meaning that memory operations
118        // > performed before a call to unpark are made visible to the thread that consumes the
119        // > token and returns from park. Note that all park and unpark operations for a given
120        // > thread form a total order and park synchronizes-with all prior unpark operations.
121        // >
122        // > In atomic ordering terms, unpark performs a Release operation and park performs the
123        // > corresponding Acquire operation. Calls to unpark for the same thread form a release
124        // > sequence.
125        // >
126        // > Note that being unblocked does not imply a call was made to unpark, because wakeups can
127        // > also be spurious. For example, a valid, but inefficient, implementation could have park
128        // > and unpark return immediately without doing anything, making all wakeups spurious.
129        self.ready_flag.store(true, Ordering::Release);
130        self.thread.unpark();
131    }
132
133    /// Get a new [`ThreadParker`] for the assigned thread.
134    pub fn parker(&self) -> ThreadParker {
135        ThreadParker::new(
136            Arc::clone(&self.ready_flag),
137            #[cfg(debug_assertions)]
138            Arc::clone(&self.shared_thread_id),
139        )
140    }
141}
142
143impl ThreadParker {
144    fn new(
145        ready_flag: Arc<AtomicBool>,
146        #[cfg(debug_assertions)] shared_thread_id: Arc<Mutex<Option<std::thread::ThreadId>>>,
147    ) -> Self {
148        Self {
149            ready_flag,
150            #[cfg(debug_assertions)]
151            shared_thread_id,
152        }
153    }
154
155    /// Park the current thread until [`ThreadUnparker::unpark()`] is called. You must only call
156    /// `park()` from the thread which the corresponding `ThreadUnparker` is assigned, otherwise a
157    /// deadlock may occur. In debug builds, this should panic instead of deadlock.
158    pub fn park(&self) {
159        while self
160            .ready_flag
161            .compare_exchange(true, false, Ordering::Acquire, Ordering::Relaxed)
162            .is_err()
163        {
164            // verify that we're parking from the proper thread (only in debug builds since this is
165            // slow)
166            #[cfg(debug_assertions)]
167            {
168                let mut shared_thread_id = self.shared_thread_id.lock().unwrap();
169
170                // it's valid to park before the unparker has been assigned to a thread
171                // (`shared_thread_id` would be `None` in this case), so we should set the thread ID
172                // here and let the unparker panic instead if this is the wrong thread
173                let shared_thread_id =
174                    shared_thread_id.get_or_insert_with(|| std::thread::current().id());
175
176                assert_eq!(
177                    *shared_thread_id,
178                    std::thread::current().id(),
179                    "`ThreadParker::park()` was called from the wrong thread"
180                );
181            }
182
183            // if unpark() was called before this park(), this park() will return immediately
184            std::thread::park();
185        }
186    }
187}
188
#[cfg(test)]
mod tests {
    use super::*;

    /// Spawns a thread that parks, then assigns and unparks it from this thread. Either ordering
    /// of `park()` and `unpark()` may occur, and both must let the spawned thread finish.
    #[test]
    fn test_parking() {
        let unparker = ThreadUnparkerUnassigned::new();
        let parker = unparker.parker();

        let handle = std::thread::spawn(move || {
            parker.park();
        });

        let unparker = unparker.assign(handle.thread().clone());

        // there is no race condition here: if `unpark` happens first, `park` will return
        // immediately
        unparker.unpark();

        handle.join().unwrap();
    }

    /// Exercises the "unpark before park" path deterministically on a single thread. The pending
    /// wakeup (from two coalesced `unpark()` calls) must make `park()` return immediately rather
    /// than block, and `park()` must consume it.
    #[test]
    fn test_unpark_before_park() {
        let unparker = ThreadUnparkerUnassigned::new().assign(std::thread::current());
        let parker = unparker.parker();

        unparker.unpark();
        unparker.unpark();
        parker.park();

        assert!(
            !parker
                .ready_flag
                .load(std::sync::atomic::Ordering::Acquire),
            "park() should consume the pending wakeup"
        );
    }
}