1use std::collections::HashMap;
2use std::sync::Arc;
3use std::sync::Mutex;
4use std::thread;
5
6use linux_api::errno::Errno;
7use linux_api::posix_types::Pid;
8use rustix::event::{self, epoll};
9use rustix::fd::AsFd;
10use rustix::fd::OwnedFd;
11use rustix::io::FdFlags;
12use rustix::process::PidfdFlags;
13
/// Watches registered processes for exit and runs registered callbacks on a
/// background thread when they do.
///
/// Each registered pid is watched through a pidfd added to an epoll instance;
/// the background thread waits on that epoll instance.
#[derive(Debug)]
pub struct ChildPidWatcher {
    // State shared with the background watcher thread.
    inner: Arc<Mutex<Inner>>,
    // Epoll instance watching the command eventfd plus each watched pid's
    // pidfd; also shared with the background thread.
    epoll: Arc<OwnedFd>,
}
22
/// Identifies a callback returned by [`ChildPidWatcher::register_callback`];
/// pass it to [`ChildPidWatcher::unregister_callback`] to cancel the callback.
pub type WatchHandle = u64;
24
/// Requests queued by API callers for the background watcher thread.
#[derive(Debug)]
enum Command {
    /// Run pending callbacks for a pid whose exit has already been observed.
    RunCallbacks(Pid),
    /// Mark the pid as unregistered; it is dropped once it has no callbacks.
    UnregisterPid(Pid),
    /// Stop the watcher thread.
    Finish,
}
31
/// Bookkeeping for one registered pid.
struct PidData {
    /// Callbacks to run (each at most once) when the pid exits, keyed by handle.
    callbacks: HashMap<WatchHandle, Box<dyn Send + FnOnce(Pid)>>,
    /// The pid's pidfd while it is still being watched; `None` once it has
    /// been unwatched (see `Inner::unwatch_pid`).
    pidfd: Option<OwnedFd>,
    /// Set once an `UnregisterPid` command has been processed. The entry is
    /// removed when this is set and no callbacks remain.
    unregistered: bool,
}
41
/// State shared between [`ChildPidWatcher`] and its background thread,
/// accessed under a mutex.
#[derive(Debug)]
struct Inner {
    /// Handle to assign to the next registered callback.
    next_handle: WatchHandle,
    /// Commands queued for the watcher thread; drained on each wakeup.
    commands: Vec<Command>,
    /// Registered pids and their state.
    pids: HashMap<Pid, PidData>,
    /// Eventfd written to wake the watcher thread when a command is queued.
    command_notifier: OwnedFd,
    /// Join handle for the watcher thread; taken when the watcher is dropped.
    thread_handle: Option<thread::JoinHandle<()>>,
}
55
56impl Inner {
57 fn send_command(&mut self, cmd: Command) {
58 self.commands.push(cmd);
59 rustix::io::write(&self.command_notifier, &1u64.to_ne_bytes()).unwrap();
60 }
61
62 fn unwatch_pid(&mut self, epoll: impl AsFd, pid: Pid) {
63 let Some(piddata) = self.pids.get_mut(&pid) else {
64 return;
66 };
67 let Some(fd) = piddata.pidfd.take() else {
68 return;
70 };
71 epoll::delete(epoll, fd).unwrap();
72 }
73
74 fn pid_has_exited(&self, pid: Pid) -> bool {
75 self.pids.get(&pid).unwrap().pidfd.is_none()
76 }
77
78 fn remove_pid(&mut self, epoll: impl AsFd, pid: Pid) {
79 debug_assert!(self.should_remove_pid(pid));
80 self.unwatch_pid(epoll, pid);
81 self.pids.remove(&pid);
82 }
83
84 fn run_callbacks_for_pid(&mut self, pid: Pid) {
85 for (_handle, cb) in self.pids.get_mut(&pid).unwrap().callbacks.drain() {
86 cb(pid)
87 }
88 }
89
90 fn should_remove_pid(&mut self, pid: Pid) -> bool {
91 let pid_data = self.pids.get(&pid).unwrap();
92 pid_data.callbacks.is_empty() && pid_data.unregistered
93 }
94
95 fn maybe_remove_pid(&mut self, epoll: impl AsFd, pid: Pid) {
96 if self.should_remove_pid(pid) {
97 self.remove_pid(epoll, pid)
98 }
99 }
100}
101
102impl ChildPidWatcher {
103 pub fn new() -> Self {
106 let epoll = Arc::new(epoll::create(epoll::CreateFlags::CLOEXEC).unwrap());
107 let command_notifier = event::eventfd(
108 0,
109 event::EventfdFlags::NONBLOCK | event::EventfdFlags::CLOEXEC,
110 )
111 .unwrap();
112 epoll::add(
113 &epoll,
114 &command_notifier,
115 epoll::EventData::new_u64(0),
116 epoll::EventFlags::IN,
117 )
118 .unwrap();
119 let watcher = ChildPidWatcher {
120 inner: Arc::new(Mutex::new(Inner {
121 next_handle: 1,
122 pids: HashMap::new(),
123 commands: Vec::new(),
124 command_notifier,
125 thread_handle: None,
126 })),
127 epoll,
128 };
129 let thread_handle = {
130 let inner = Arc::clone(&watcher.inner);
131 let epoll = watcher.epoll.clone();
132 thread::Builder::new()
133 .name("child-pid-watcher".into())
134 .spawn(move || ChildPidWatcher::thread_loop(&inner, &epoll))
135 .unwrap()
136 };
137 watcher.inner.lock().unwrap().thread_handle = Some(thread_handle);
138 watcher
139 }
140
141 fn thread_loop(inner: &Mutex<Inner>, epoll: impl AsFd) {
142 let mut commands = Vec::new();
143 let mut done = false;
144 while !done {
145 let mut events = epoll::EventVec::with_capacity(10);
146 match epoll::wait(epoll.as_fd(), &mut events, -1) {
147 Ok(()) => (),
148 Err(rustix::io::Errno::INTR) => {
149 continue;
151 }
152 Err(e) => panic!("epoll_wait: {:?}", e),
153 };
154
155 let mut inner = inner.lock().unwrap();
161
162 for event in events.into_iter() {
163 if event.data.u64() == 0 {
164 continue;
167 }
168 let pid = Pid::from_raw(i32::try_from(event.data.u64()).unwrap()).unwrap();
169 inner.unwatch_pid(epoll.as_fd(), pid);
170 inner.run_callbacks_for_pid(pid);
171 inner.maybe_remove_pid(epoll.as_fd(), pid);
172 }
173 let mut buf = [0; 8];
176 let res = rustix::io::read(&inner.command_notifier, &mut buf);
177 debug_assert!(match res {
178 Ok(8) => true,
179 Ok(i) => panic!("Unexpected read size {}", i),
180 Err(rustix::io::Errno::AGAIN) => true,
181 Err(e) => panic!("Unexpected error {:?}", e),
182 });
183 std::mem::swap(&mut commands, &mut inner.commands);
185 for cmd in commands.drain(..) {
186 match cmd {
187 Command::RunCallbacks(pid) => {
188 debug_assert!(inner.pid_has_exited(pid));
189 inner.run_callbacks_for_pid(pid);
190 inner.maybe_remove_pid(epoll.as_fd(), pid);
191 }
192 Command::UnregisterPid(pid) => {
193 if let Some(pid_data) = inner.pids.get_mut(&pid) {
194 pid_data.unregistered = true;
195 inner.maybe_remove_pid(epoll.as_fd(), pid);
196 }
197 }
198 Command::Finish => {
199 done = true;
200 break;
206 }
207 }
208 }
209 }
210 }
211
212 pub unsafe fn fork_watchable(&self, child_fn: impl FnOnce()) -> Result<Pid, Errno> {
226 let raw_pid = Errno::result_from_libc_errno(-1, unsafe { libc::syscall(libc::SYS_fork) })?;
227 if raw_pid == 0 {
228 child_fn();
229 panic!("child_fn shouldn't have returned");
230 }
231 let pid = Pid::from_raw(raw_pid.try_into().unwrap()).unwrap();
232 self.register_pid(pid);
233
234 Ok(pid)
235 }
236
237 pub fn register_pid(&self, pid: Pid) {
248 let mut inner = self.inner.lock().unwrap();
249 let pidfd = rustix::process::pidfd_open(pid.into(), PidfdFlags::NONBLOCK)
252 .unwrap_or_else(|e| panic!("pidfd_open failed for {pid:?}: {e:?}"));
253 debug_assert!(
255 rustix::io::fcntl_getfd(&pidfd)
256 .unwrap()
257 .contains(FdFlags::CLOEXEC),
258 "pidfd_open unexpected didn't set CLOEXEC"
259 );
260 epoll::add(
261 &self.epoll,
262 &pidfd,
263 epoll::EventData::new_u64(pid.as_raw_nonzero().get().try_into().unwrap()),
264 epoll::EventFlags::IN,
265 )
266 .unwrap();
267
268 let prev = inner.pids.insert(
269 pid,
270 PidData {
271 callbacks: HashMap::new(),
272 pidfd: Some(pidfd),
273 unregistered: false,
274 },
275 );
276 assert!(prev.is_none());
277 }
278
279 pub fn unregister_pid(&self, pid: Pid) {
290 let mut inner = self.inner.lock().unwrap();
294 inner.send_command(Command::UnregisterPid(pid));
295 }
296
297 pub fn register_callback(
305 &self,
306 pid: Pid,
307 callback: impl Send + FnOnce(Pid) + 'static,
308 ) -> WatchHandle {
309 let mut inner = self.inner.lock().unwrap();
310 let handle = inner.next_handle;
311 inner.next_handle += 1;
312 let pid_data = inner.pids.get_mut(&pid).unwrap();
313 assert!(!pid_data.unregistered);
314 pid_data.callbacks.insert(handle, Box::new(callback));
315 if pid_data.pidfd.is_none() {
316 inner.send_command(Command::RunCallbacks(pid));
318 }
319 handle
320 }
321
322 pub fn unregister_callback(&self, pid: Pid, handle: WatchHandle) {
328 let mut inner = self.inner.lock().unwrap();
329 if let Some(pid_data) = inner.pids.get_mut(&pid) {
330 pid_data.callbacks.remove(&handle);
331 inner.maybe_remove_pid(&self.epoll, pid);
332 }
333 }
334}
335
336impl Default for ChildPidWatcher {
337 fn default() -> Self {
338 Self::new()
339 }
340}
341
342impl Drop for ChildPidWatcher {
343 fn drop(&mut self) {
344 let handle = {
345 let mut inner = self.inner.lock().unwrap();
346 inner.send_command(Command::Finish);
347 inner.thread_handle.take().unwrap()
348 };
349 handle.join().unwrap();
350 }
351}
352
353impl std::fmt::Debug for PidData {
354 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
355 f.debug_struct("PidData")
356 .field("fd", &self.pidfd)
357 .field("unregistered", &self.unregistered)
358 .finish_non_exhaustive()
359 }
360}
361
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Condvar};

    use nix::sys::eventfd::EventFd;
    use rustix::fd::AsRawFd;
    use rustix::process::{WaitOptions, waitpid};

    use super::*;

    /// Whether `/proc/<pid>/stat` reports `pid` in the zombie (`Z`) state.
    fn is_zombie(pid: Pid) -> bool {
        let stat_name = format!("/proc/{}/stat", pid.as_raw_nonzero().get());
        let contents = std::fs::read_to_string(stat_name).unwrap();
        contents.contains(") Z")
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn register_before_exit() {
        let notifier = EventFd::new().unwrap();

        let watcher = ChildPidWatcher::new();
        let child = unsafe {
            watcher.fork_watchable(|| {
                // Block until the parent signals, so the child can't exit
                // before the callback is registered.
                let mut buf = [0; 8];
                nix::unistd::read(notifier.as_raw_fd(), &mut buf).unwrap();
                libc::_exit(42);
            })
        }
        .unwrap();

        let callback_ran = Arc::new((Mutex::new(false), Condvar::new()));
        {
            let callback_ran = callback_ran.clone();
            watcher.register_callback(
                child,
                Box::new(move |pid| {
                    assert_eq!(pid, child);
                    *callback_ran.0.lock().unwrap() = true;
                    callback_ran.1.notify_all();
                }),
            );
        }

        watcher.unregister_pid(child);

        // The child is still blocked, so it can't have exited yet.
        let status = waitpid(Some(child.into()), WaitOptions::NOHANG).unwrap();
        assert!(status.is_none(), "Unexpected status: {status:?}");

        assert!(!*callback_ran.0.lock().unwrap());

        // Let the child exit.
        nix::unistd::write(&notifier, &1u64.to_ne_bytes()).unwrap();

        let mut callback_ran_lock = callback_ran.0.lock().unwrap();
        while !*callback_ran_lock {
            callback_ran_lock = callback_ran.1.wait(callback_ran_lock).unwrap();
        }

        let status = waitpid(Some(child.into()), WaitOptions::empty())
            .unwrap()
            .unwrap();
        assert_eq!(status.exit_status(), Some(42));
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn register_after_exit() {
        let child = match unsafe { libc::fork() } {
            -1 => panic!("fork failed: {}", std::io::Error::last_os_error()),
            0 => {
                unsafe { libc::_exit(42) };
            }
            child => Pid::from_raw(child).unwrap(),
        };

        // Wait until the child has exited (but hasn't been reaped).
        while !is_zombie(child) {
            unsafe {
                libc::sched_yield();
            }
        }

        let watcher = ChildPidWatcher::new();
        watcher.register_pid(child);

        let callback_ran = Arc::new((Mutex::new(false), Condvar::new()));
        {
            let callback_ran = callback_ran.clone();
            watcher.register_callback(
                child,
                Box::new(move |pid| {
                    assert_eq!(pid, child);
                    *callback_ran.0.lock().unwrap() = true;
                    callback_ran.1.notify_all();
                }),
            );
        }

        watcher.unregister_pid(child);

        let mut callback_ran_lock = callback_ran.0.lock().unwrap();
        while !*callback_ran_lock {
            callback_ran_lock = callback_ran.1.wait(callback_ran_lock).unwrap();
        }

        assert_eq!(
            waitpid(Some(child.into()), WaitOptions::empty())
                .unwrap()
                .unwrap()
                .exit_status(),
            Some(42)
        );
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn register_multiple() {
        let cb1_ran = Arc::new((Mutex::new(false), Condvar::new()));
        let cb2_ran = Arc::new((Mutex::new(false), Condvar::new()));

        let watcher = ChildPidWatcher::new();
        let child = unsafe {
            watcher.fork_watchable(|| {
                libc::_exit(42);
            })
        }
        .unwrap();

        for cb_ran in [&cb1_ran, &cb2_ran] {
            let cb_ran = Arc::clone(cb_ran);
            watcher.register_callback(
                child,
                Box::new(move |pid| {
                    assert_eq!(pid, child);
                    *cb_ran.0.lock().unwrap() = true;
                    cb_ran.1.notify_all();
                }),
            );
        }

        watcher.unregister_pid(child);

        for cb_ran in [cb1_ran, cb2_ran] {
            let mut cb_ran_lock = cb_ran.0.lock().unwrap();
            while !*cb_ran_lock {
                cb_ran_lock = cb_ran.1.wait(cb_ran_lock).unwrap();
            }
        }

        assert_eq!(
            waitpid(Some(child.into()), WaitOptions::empty())
                .unwrap()
                .unwrap()
                .exit_status(),
            Some(42)
        );
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn unregister_one() {
        let cb1_ran = Arc::new((Mutex::new(false), Condvar::new()));
        let cb2_ran = Arc::new((Mutex::new(false), Condvar::new()));

        let notifier = EventFd::new().unwrap();

        let watcher = ChildPidWatcher::new();
        let child = unsafe {
            watcher.fork_watchable(|| {
                // Block until the parent signals, so the child can't exit
                // before a callback is unregistered.
                let mut buf = [0; 8];
                nix::unistd::read(notifier.as_raw_fd(), &mut buf).unwrap();
                libc::_exit(42);
            })
        }
        .unwrap();

        let handles: Vec<WatchHandle> = [&cb1_ran, &cb2_ran]
            .iter()
            .cloned()
            .map(|cb_ran| {
                let cb_ran = cb_ran.clone();
                watcher.register_callback(
                    child,
                    Box::new(move |pid| {
                        assert_eq!(pid, child);
                        *cb_ran.0.lock().unwrap() = true;
                        cb_ran.1.notify_all();
                    }),
                )
            })
            .collect();

        watcher.unregister_pid(child);

        watcher.unregister_callback(child, handles[0]);

        // Let the child exit.
        nix::unistd::write(&notifier, &1u64.to_ne_bytes()).unwrap();

        let mut cb_ran_lock = cb2_ran.0.lock().unwrap();
        while !*cb_ran_lock {
            cb_ran_lock = cb2_ran.1.wait(cb_ran_lock).unwrap();
        }

        assert!(!*cb1_ran.0.lock().unwrap());

        assert_eq!(
            waitpid(Some(child.into()), WaitOptions::empty())
                .unwrap()
                .unwrap()
                .exit_status(),
            Some(42)
        );
    }
}