shadow_rs/utility/
childpid_watcher.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::sync::Mutex;
4use std::thread;
5
6use linux_api::errno::Errno;
7use linux_api::posix_types::Pid;
8use rustix::event::{self, epoll};
9use rustix::fd::AsFd;
10use rustix::fd::OwnedFd;
11use rustix::io::FdFlags;
12use rustix::process::PidfdFlags;
13
/// Utility for monitoring a set of child pid's, calling registered callbacks
/// when one exits or is killed. Starts a background thread, which is shut down
/// when the object is dropped.
#[derive(Debug)]
pub struct ChildPidWatcher {
    // State shared between the caller-facing API and the watcher thread.
    inner: Arc<Mutex<Inner>>,
    // epoll instance the watcher thread waits on; pidfds and the command
    // notifier are registered here.
    epoll: Arc<OwnedFd>,
}
22
/// Handle identifying a callback registered via `register_callback`.
/// Guaranteed to be non-zero (handles are allocated starting from 1).
pub type WatchHandle = u64;
24
/// Commands queued in `Inner::commands` for the watcher thread, which is woken
/// via a write to `Inner::command_notifier`.
#[derive(Debug)]
enum Command {
    /// Run (and drain) all callbacks registered for this pid. Only sent when
    /// the pid is already known to have exited.
    RunCallbacks(Pid),
    /// Mark this pid unregistered; its data is removed once no callbacks remain.
    UnregisterPid(Pid),
    /// Shut down the watcher thread.
    Finish,
}
31
/// Per-pid bookkeeping, stored in `Inner::pids`.
struct PidData {
    // Registered callbacks, keyed by the handle returned to the caller.
    callbacks: HashMap<WatchHandle, Box<dyn Send + FnOnce(Pid)>>,
    // pidfd for the watched process, registered with the watcher's epoll.
    // After the pid has exited, this fd is closed and set to None.
    pidfd: Option<OwnedFd>,
    // Whether this pid has been unregistered. The whole struct is removed after
    // both the pid is unregistered, and `callbacks` is empty.
    unregistered: bool,
}
41
/// State shared between the public API and the watcher thread, protected by
/// the `Mutex` in `ChildPidWatcher::inner`.
#[derive(Debug)]
struct Inner {
    // Next unique handle ID.
    next_handle: WatchHandle,
    // Pending commands for watcher thread.
    commands: Vec<Command>,
    // Data for each monitored pid.
    pids: HashMap<Pid, PidData>,
    // event_fd used to notify watcher thread via epoll. Calling thread writes
    // an 8-byte value, which the watcher thread reads to reset.
    command_notifier: OwnedFd,
    // Join handle for the watcher thread; taken (and joined) in `Drop`.
    thread_handle: Option<thread::JoinHandle<()>>,
}
55
56impl Inner {
57    fn send_command(&mut self, cmd: Command) {
58        self.commands.push(cmd);
59        rustix::io::write(&self.command_notifier, &1u64.to_ne_bytes()).unwrap();
60    }
61
62    fn unwatch_pid(&mut self, epoll: impl AsFd, pid: Pid) {
63        let Some(piddata) = self.pids.get_mut(&pid) else {
64            // Already unregistered the pid
65            return;
66        };
67        let Some(fd) = piddata.pidfd.take() else {
68            // Already unwatched the pid
69            return;
70        };
71        epoll::delete(epoll, fd).unwrap();
72    }
73
74    fn pid_has_exited(&self, pid: Pid) -> bool {
75        self.pids.get(&pid).unwrap().pidfd.is_none()
76    }
77
78    fn remove_pid(&mut self, epoll: impl AsFd, pid: Pid) {
79        debug_assert!(self.should_remove_pid(pid));
80        self.unwatch_pid(epoll, pid);
81        self.pids.remove(&pid);
82    }
83
84    fn run_callbacks_for_pid(&mut self, pid: Pid) {
85        for (_handle, cb) in self.pids.get_mut(&pid).unwrap().callbacks.drain() {
86            cb(pid)
87        }
88    }
89
90    fn should_remove_pid(&mut self, pid: Pid) -> bool {
91        let pid_data = self.pids.get(&pid).unwrap();
92        pid_data.callbacks.is_empty() && pid_data.unregistered
93    }
94
95    fn maybe_remove_pid(&mut self, epoll: impl AsFd, pid: Pid) {
96        if self.should_remove_pid(pid) {
97            self.remove_pid(epoll, pid)
98        }
99    }
100}
101
impl ChildPidWatcher {
    /// Create a ChildPidWatcher. Spawns a background thread, which is joined
    /// when the object is dropped.
    pub fn new() -> Self {
        let epoll = Arc::new(epoll::create(epoll::CreateFlags::CLOEXEC).unwrap());
        // eventfd that `Inner::send_command` writes to, waking the watcher
        // thread out of `epoll_wait`.
        let command_notifier = event::eventfd(
            0,
            event::EventfdFlags::NONBLOCK | event::EventfdFlags::CLOEXEC,
        )
        .unwrap();
        // Event data 0 identifies the command notifier; watched pids use their
        // non-zero raw pid value as event data (see `register_pid`).
        epoll::add(
            &epoll,
            &command_notifier,
            epoll::EventData::new_u64(0),
            epoll::EventFlags::IN,
        )
        .unwrap();
        let watcher = ChildPidWatcher {
            inner: Arc::new(Mutex::new(Inner {
                next_handle: 1,
                pids: HashMap::new(),
                commands: Vec::new(),
                command_notifier,
                thread_handle: None,
            })),
            epoll,
        };
        let thread_handle = {
            let inner = Arc::clone(&watcher.inner);
            let epoll = watcher.epoll.clone();
            thread::Builder::new()
                .name("child-pid-watcher".into())
                .spawn(move || ChildPidWatcher::thread_loop(&inner, &epoll))
                .unwrap()
        };
        watcher.inner.lock().unwrap().thread_handle = Some(thread_handle);
        watcher
    }

    /// Body of the background watcher thread: blocks in `epoll_wait` for pidfd
    /// readability (child exit) and command-notifier events, runs callbacks
    /// for exited pids, and processes queued `Command`s until
    /// `Command::Finish` is received.
    fn thread_loop(inner: &Mutex<Inner>, epoll: impl AsFd) {
        let mut commands = Vec::new();
        let mut done = false;
        while !done {
            let mut events = epoll::EventVec::with_capacity(10);
            // -1: wait with no timeout.
            match epoll::wait(epoll.as_fd(), &mut events, -1) {
                Ok(()) => (),
                Err(rustix::io::Errno::INTR) => {
                    // Just try again.
                    continue;
                }
                Err(e) => panic!("epoll_wait: {:?}", e),
            };

            // We hold the lock the whole time we're processing events. While it'd
            // be nice to avoid holding it while executing callbacks (and therefore
            // not require that callbacks don't call ChildPidWatcher APIs), that'd
            // make it difficult to guarantee a callback *won't* be run if the
            // caller unregisters it.
            let mut inner = inner.lock().unwrap();

            for event in events.into_iter() {
                if event.data.u64() == 0 {
                    // We get an event for pid=0 when there's a write to the
                    // command_notifier; Ignore that here and handle below.
                    continue;
                }
                // Non-zero event data is the raw pid of an exited child, as
                // registered in `register_pid`.
                let pid = Pid::from_raw(i32::try_from(event.data.u64()).unwrap()).unwrap();
                inner.unwatch_pid(epoll.as_fd(), pid);
                inner.run_callbacks_for_pid(pid);
                inner.maybe_remove_pid(epoll.as_fd(), pid);
            }
            // Reading an eventfd always returns an 8 byte integer. Do so to ensure it's
            // no longer marked 'readable'.
            let mut buf = [0; 8];
            let res = rustix::io::read(&inner.command_notifier, &mut buf);
            debug_assert!(match res {
                Ok(8) => true,
                Ok(i) => panic!("Unexpected read size {}", i),
                // The eventfd is non-blocking; EAGAIN just means there was no
                // pending notification.
                Err(rustix::io::Errno::AGAIN) => true,
                Err(e) => panic!("Unexpected error {:?}", e),
            });
            // Run commands. Swap the queue out so we can drain it without
            // fighting the borrow on `inner`.
            std::mem::swap(&mut commands, &mut inner.commands);
            for cmd in commands.drain(..) {
                match cmd {
                    Command::RunCallbacks(pid) => {
                        debug_assert!(inner.pid_has_exited(pid));
                        inner.run_callbacks_for_pid(pid);
                        inner.maybe_remove_pid(epoll.as_fd(), pid);
                    }
                    Command::UnregisterPid(pid) => {
                        if let Some(pid_data) = inner.pids.get_mut(&pid) {
                            pid_data.unregistered = true;
                            inner.maybe_remove_pid(epoll.as_fd(), pid);
                        }
                    }
                    Command::Finish => {
                        done = true;
                        // There could be more commands queued and/or more epoll
                        // events ready, but it doesn't matter. We don't
                        // guarantee to callers whether callbacks have run or
                        // not after having sent `Finish`; only that no more
                        // callbacks will run after the thread is joined.
                        break;
                    }
                }
            }
        }
    }

    /// Fork a child and register it. Uses `fork` internally; if `vfork` is desired,
    /// use `register_pid` instead.
    ///
    /// Panics if `child_fn` returns.
    /// TODO: change the type to `FnOnce() -> !` once that's stabilized in Rust.
    /// <https://github.com/rust-lang/rust/issues/35121>
    ///
    /// # Safety
    ///
    /// As for fork in Rust in general. *Probably*, *mostly*, safe, since the
    /// child process gets its own copy of the address space and OS resources etc.
    /// Still, there may be some dragons here. Best to call exec before too long
    /// in the child.
    pub unsafe fn fork_watchable(&self, child_fn: impl FnOnce()) -> Result<Pid, Errno> {
        let raw_pid = Errno::result_from_libc_errno(-1, unsafe { libc::syscall(libc::SYS_fork) })?;
        if raw_pid == 0 {
            // We're in the child.
            child_fn();
            panic!("child_fn shouldn't have returned");
        }
        let pid = Pid::from_raw(raw_pid.try_into().unwrap()).unwrap();
        self.register_pid(pid);

        Ok(pid)
    }

    /// Register interest in `pid`.
    ///
    /// Will succeed even if `pid` is already dead, in which case callbacks
    /// registered for this `pid` will immediately be scheduled to run.
    ///
    /// `pid` must refer to some process, but that process may be a zombie (dead
    /// but not yet reaped). Panics if `pid` doesn't exist at all.  The caller
    /// should ensure the process has not been reaped before calling this
    /// function both to avoid such panics, and to avoid accidentally watching
    /// an unrelated process with a recycled `pid`.
    pub fn register_pid(&self, pid: Pid) {
        let mut inner = self.inner.lock().unwrap();
        // We defensively make the pidfd non-blocking, since we intend to always
        // use epoll to validate that it's ready before operating on it.
        let pidfd = rustix::process::pidfd_open(pid.into(), PidfdFlags::NONBLOCK)
            .unwrap_or_else(|e| panic!("pidfd_open failed for {pid:?}: {e:?}"));
        // `pidfd_open(2)`: the close-on-exec flag is set on the file descriptor.
        debug_assert!(
            rustix::io::fcntl_getfd(&pidfd)
                .unwrap()
                .contains(FdFlags::CLOEXEC),
            "pidfd_open unexpected didn't set CLOEXEC"
        );
        // Use the raw pid as the epoll event data; 0 is reserved for the
        // command notifier, and pids are always non-zero.
        epoll::add(
            &self.epoll,
            &pidfd,
            epoll::EventData::new_u64(pid.as_raw_nonzero().get().try_into().unwrap()),
            epoll::EventFlags::IN,
        )
        .unwrap();

        let prev = inner.pids.insert(
            pid,
            PidData {
                callbacks: HashMap::new(),
                pidfd: Some(pidfd),
                unregistered: false,
            },
        );
        // Double-registration of the same pid is a caller bug.
        assert!(prev.is_none());
    }

    // TODO: Re-enable when Rust supports vfork: https://github.com/rust-lang/rust/issues/58314
    // pub unsafe fn vfork_watchable(&self, child_fn: impl FnOnce()) -> Result<Pid, nix::Error> {
    //     unsafe { self.fork_watchable_internal(libc::SYS_vfork, child_fn) }
    // }

    /// Unregister the pid. After unregistration, no more callbacks may be
    /// registered for the given pid. Already-registered callbacks will still be
    /// called if and when the pid exits unless individually unregistered.
    ///
    /// Safe to call multiple times.
    pub fn unregister_pid(&self, pid: Pid) {
        // Let the worker handle the actual unregistration. This avoids a race
        // where we unregister a pid at the same time as the worker thread
        // receives an epoll event for it.
        let mut inner = self.inner.lock().unwrap();
        inner.send_command(Command::UnregisterPid(pid));
    }

    /// Call `callback` from another thread after the child `pid`
    /// has exited, including if it has already exited. Does *not* reap the
    /// child itself.
    ///
    /// The returned handle is guaranteed to be non-zero.
    ///
    /// Panics if `pid` isn't registered.
    pub fn register_callback(
        &self,
        pid: Pid,
        callback: impl Send + FnOnce(Pid) + 'static,
    ) -> WatchHandle {
        let mut inner = self.inner.lock().unwrap();
        let handle = inner.next_handle;
        inner.next_handle += 1;
        let pid_data = inner.pids.get_mut(&pid).unwrap();
        assert!(!pid_data.unregistered);
        pid_data.callbacks.insert(handle, Box::new(callback));
        if pid_data.pidfd.is_none() {
            // pid is already dead. Run the callback we just registered.
            inner.send_command(Command::RunCallbacks(pid));
        }
        handle
    }

    /// Unregisters a callback. After returning, the corresponding callback is
    /// guaranteed either to have already run, or to never run. i.e. it's safe to
    /// free data that the callback might otherwise access.
    ///
    /// No-op if `pid` isn't registered.
    pub fn unregister_callback(&self, pid: Pid, handle: WatchHandle) {
        let mut inner = self.inner.lock().unwrap();
        if let Some(pid_data) = inner.pids.get_mut(&pid) {
            pid_data.callbacks.remove(&handle);
            inner.maybe_remove_pid(&self.epoll, pid);
        }
    }
}
335
336impl Default for ChildPidWatcher {
337    fn default() -> Self {
338        Self::new()
339    }
340}
341
342impl Drop for ChildPidWatcher {
343    fn drop(&mut self) {
344        let handle = {
345            let mut inner = self.inner.lock().unwrap();
346            inner.send_command(Command::Finish);
347            inner.thread_handle.take().unwrap()
348        };
349        handle.join().unwrap();
350    }
351}
352
353impl std::fmt::Debug for PidData {
354    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
355        f.debug_struct("PidData")
356            .field("fd", &self.pidfd)
357            .field("unregistered", &self.unregistered)
358            .finish_non_exhaustive()
359    }
360}
361
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Condvar};

    use nix::sys::eventfd::EventFd;
    use rustix::fd::AsRawFd;
    use rustix::process::{WaitOptions, waitpid};

    use super::*;

    /// Whether `pid` is a zombie (exited but not yet reaped), per
    /// /proc/<pid>/stat's state field.
    fn is_zombie(pid: Pid) -> bool {
        let stat_name = format!("/proc/{}/stat", pid.as_raw_nonzero().get());
        let contents = std::fs::read_to_string(stat_name).unwrap();
        contents.contains(") Z")
    }

    #[test]
    // can't call foreign function: pipe
    #[cfg_attr(miri, ignore)]
    fn register_before_exit() {
        let notifier = EventFd::new().unwrap();

        let watcher = ChildPidWatcher::new();
        let child = unsafe {
            watcher.fork_watchable(|| {
                let mut buf = [0; 8];
                // Wait for parent to register its callback.
                nix::unistd::read(notifier.as_raw_fd(), &mut buf).unwrap();
                libc::_exit(42);
            })
        }
        .unwrap();

        let callback_ran = Arc::new((Mutex::new(false), Condvar::new()));
        {
            let callback_ran = callback_ran.clone();
            watcher.register_callback(
                child,
                Box::new(move |pid| {
                    assert_eq!(pid, child);
                    *callback_ran.0.lock().unwrap() = true;
                    callback_ran.1.notify_all();
                }),
            );
        }

        // Should be safe to unregister the pid now.
        // We won't be able to register any more callbacks, but existing ones
        // should still work.
        watcher.unregister_pid(child);

        // Child should still be alive.
        let status = waitpid(Some(child.into()), WaitOptions::NOHANG).unwrap();
        assert!(status.is_none(), "Unexpected status: {status:?}");

        // Callback shouldn't have run yet.
        assert!(!*callback_ran.0.lock().unwrap());

        // Let the child exit.
        nix::unistd::write(&notifier, &1u64.to_ne_bytes()).unwrap();

        // Wait for our callback to run.
        let mut callback_ran_lock = callback_ran.0.lock().unwrap();
        while !*callback_ran_lock {
            callback_ran_lock = callback_ran.1.wait(callback_ran_lock).unwrap();
        }

        // Child should be ready to be reaped.
        // TODO: use WNOHANG here if we go back to a pidfd-based implementation.
        // With the current fd-based implementation we may be notified before kernel
        // marks the child reapable.
        let status = waitpid(Some(child.into()), WaitOptions::empty())
            .unwrap()
            .unwrap();
        assert_eq!(status.exit_status(), Some(42));
    }

    #[test]
    // can't call foreign functions
    #[cfg_attr(miri, ignore)]
    fn register_after_exit() {
        let child = match unsafe { libc::fork() } {
            0 => {
                unsafe { libc::_exit(42) };
            }
            child => Pid::from_raw(child).unwrap(),
        };

        // Wait until child is dead, but don't reap it yet.
        while !is_zombie(child) {
            unsafe {
                libc::sched_yield();
            }
        }

        let watcher = ChildPidWatcher::new();
        watcher.register_pid(child);

        // Used to wait until after the ChildPidWatcher has run our callback
        let callback_ran = Arc::new((Mutex::new(false), Condvar::new()));
        {
            let callback_ran = callback_ran.clone();
            watcher.register_callback(
                child,
                Box::new(move |pid| {
                    assert_eq!(pid, child);
                    *callback_ran.0.lock().unwrap() = true;
                    callback_ran.1.notify_all();
                }),
            );
        }

        // Should be safe to unregister the pid now.
        // We won't be able to register any more callbacks, but existing ones
        // should still work.
        watcher.unregister_pid(child);

        // Wait for our callback to run.
        let mut callback_ran_lock = callback_ran.0.lock().unwrap();
        while !*callback_ran_lock {
            callback_ran_lock = callback_ran.1.wait(callback_ran_lock).unwrap();
        }

        // Child should be ready to be reaped.
        // TODO: use WNOHANG here if we go back to a pidfd-based implementation.
        // With the current fd-based implementation we may be notified before kernel
        // marks the child reapable.
        assert_eq!(
            waitpid(Some(child.into()), WaitOptions::empty())
                .unwrap()
                .unwrap()
                .exit_status(),
            Some(42)
        );
    }

    #[test]
    // can't call foreign function: pipe
    #[cfg_attr(miri, ignore)]
    fn register_multiple() {
        let cb1_ran = Arc::new((Mutex::new(false), Condvar::new()));
        let cb2_ran = Arc::new((Mutex::new(false), Condvar::new()));

        let watcher = ChildPidWatcher::new();
        let child = unsafe {
            watcher.fork_watchable(|| {
                libc::_exit(42);
            })
        }
        .unwrap();

        for cb_ran in vec![cb1_ran.clone(), cb2_ran.clone()].drain(..) {
            let cb_ran = cb_ran.clone();
            watcher.register_callback(
                child,
                Box::new(move |pid| {
                    assert_eq!(pid, child);
                    *cb_ran.0.lock().unwrap() = true;
                    cb_ran.1.notify_all();
                }),
            );
        }

        // Should be safe to unregister the pid now.
        // We won't be able to register any more callbacks, but existing ones
        // should still work.
        watcher.unregister_pid(child);

        // Wait for both callbacks to run.
        for cb_ran in vec![cb1_ran, cb2_ran].drain(..) {
            let mut cb_ran_lock = cb_ran.0.lock().unwrap();
            while !*cb_ran_lock {
                cb_ran_lock = cb_ran.1.wait(cb_ran_lock).unwrap();
            }
        }

        // Child should be ready to be reaped.
        // TODO: use WNOHANG here if we go back to a pidfd-based implementation.
        // With the current fd-based implementation we may be notified before kernel
        // marks the child reapable.
        assert_eq!(
            waitpid(Some(child.into()), WaitOptions::empty())
                .unwrap()
                .unwrap()
                .exit_status(),
            Some(42)
        );
    }

    #[test]
    // can't call foreign function
    #[cfg_attr(miri, ignore)]
    fn unregister_one() {
        let cb1_ran = Arc::new((Mutex::new(false), Condvar::new()));
        let cb2_ran = Arc::new((Mutex::new(false), Condvar::new()));

        let notifier = EventFd::new().unwrap();

        let watcher = ChildPidWatcher::new();
        let child = unsafe {
            watcher.fork_watchable(|| {
                let mut buf = [0; 8];
                // Wait for parent to register its callback.
                nix::unistd::read(notifier.as_raw_fd(), &mut buf).unwrap();
                libc::_exit(42);
            })
        }
        .unwrap();

        let handles: Vec<WatchHandle> = [&cb1_ran, &cb2_ran]
            .iter()
            .cloned()
            .map(|cb_ran| {
                let cb_ran = cb_ran.clone();
                watcher.register_callback(
                    child,
                    Box::new(move |pid| {
                        assert_eq!(pid, child);
                        *cb_ran.0.lock().unwrap() = true;
                        cb_ran.1.notify_all();
                    }),
                )
            })
            .collect();

        // Should be safe to unregister the pid now.
        // We won't be able to register any more callbacks, but existing ones
        // should still work.
        watcher.unregister_pid(child);

        // Unregister the first callback; it must now never run.
        watcher.unregister_callback(child, handles[0]);

        // Let the child exit.
        nix::unistd::write(&notifier, &1u64.to_ne_bytes()).unwrap();

        // Wait for the still-registered callback to run.
        let mut cb_ran_lock = cb2_ran.0.lock().unwrap();
        while !*cb_ran_lock {
            cb_ran_lock = cb2_ran.1.wait(cb_ran_lock).unwrap();
        }

        // The unregistered cb should *not* have run.
        assert!(!*cb1_ran.0.lock().unwrap());

        // Child should be ready to be reaped.
        // TODO: use WNOHANG here if we go back to a pidfd-based implementation.
        // With the current fd-based implementation we may be notified before kernel
        // marks the child reapable.
        assert_eq!(
            waitpid(Some(child.into()), WaitOptions::empty())
                .unwrap()
                .unwrap()
                .exit_status(),
            Some(42)
        );
    }
}