shadow_rs/utility/
status_bar.rs

1use std::io::Write;
2use std::sync::atomic::AtomicBool;
3use std::sync::{Arc, RwLock};
4use std::time::Duration;
5
// Terminal control sequences used to draw the status bar. Each one starts with the ESC byte
// (0x1B) followed by '['.

/// Remember the current cursor position.
const SAVE_CURSOR: &str = "\x1B[s";
/// Jump back to the position remembered by `SAVE_CURSOR`.
const RESTORE_CURSOR: &str = "\x1B[u";
/// Move the cursor to the beginning of the next line.
const NEXT_LINE: &str = "\x1B[1E";
/// Move the cursor to the beginning of the previous line.
const PREV_LINE: &str = "\x1B[1F";
/// Erase from the cursor to the end of the current line.
const CLEAR: &str = "\x1B[K";
/// Reset the scroll region (no explicit bounds are given).
const RESTORE_SCROLL_REGION: &str = "\x1B[r";
/// Move the cursor to row 9999; this is used to reach the last row of the terminal, which
/// presumably relies on the terminal clamping the row to its actual height.
const LAST_LINE: &str = "\x1B[9999H";
13
/// Marker trait for values that can be used as status bar state: they must be printable with
/// `Display` and safe to share with the background drawing thread (`Send + Sync`).
pub trait StatusBarState: std::fmt::Display + Send + Sync {}

/// Every type that satisfies the bounds is automatically a `StatusBarState`.
impl<S> StatusBarState for S where S: std::fmt::Display + Send + Sync {}
16
17pub struct StatusBar<T: 'static + StatusBarState> {
18    state: Arc<Status<T>>,
19    stop_flag: Arc<AtomicBool>,
20    thread: Option<std::thread::JoinHandle<()>>,
21}
22
23impl<T: 'static + StatusBarState> StatusBar<T> {
24    /// Create and start drawing the status bar.
25    pub fn new(state: T, redraw_interval: Duration) -> Self {
26        let state = Arc::new(Status::new(state));
27        let stop_flag = Arc::new(AtomicBool::new(false));
28
29        Self {
30            state: Arc::clone(&state),
31            stop_flag: Arc::clone(&stop_flag),
32            thread: Some(std::thread::spawn(move || {
33                Self::redraw_loop(state, stop_flag, redraw_interval);
34            })),
35        }
36    }
37
38    fn redraw_loop(state: Arc<Status<T>>, stop_flag: Arc<AtomicBool>, redraw_interval: Duration) {
39        // we re-draw the status bar every interval, even if the state hasn't changed, since the
40        // terminal might have been resized and the scroll region might have been reset
41        while !stop_flag.load(std::sync::atomic::Ordering::Acquire) {
42            // the window size might change during the simulation, so we re-check it each time
43            let rows = match tiocgwinsz() {
44                Ok(x) => x.ws_row,
45                Err(e) => {
46                    log::error!("Status bar ioctl failed ({}). Stopping the status bar.", e);
47                    break;
48                }
49            };
50
51            if rows > 1 {
52                #[rustfmt::skip]
53                let to_print = [
54                    // Restore the scroll region since some terminals handle scroll regions
55                    // differently. For example, when using '{next_line}' some terminals will
56                    // allow the cursor to move outside of the scroll region, and others don't.
57                    SAVE_CURSOR, RESTORE_SCROLL_REGION, RESTORE_CURSOR,
58                    // This will scroll the buffer up only if the cursor is on the last row.
59                    SAVE_CURSOR, "\n", RESTORE_CURSOR,
60                    // This will move the cursor up only if the cursor is on the last row (to
61                    // match the previous scroll behaviour).
62                    NEXT_LINE, PREV_LINE,
63                    // The cursor is currently at the correct location, so save it for later.
64                    SAVE_CURSOR,
65                    // Set the scroll region to include all rows but the last.
66                    &format!("\u{1B}[1;{}r", rows - 1),
67                    // Move to the last row and write the message.
68                    LAST_LINE, &format!("{}", *state.inner.read().unwrap()), CLEAR,
69                    // Restore the cursor position.
70                    RESTORE_CURSOR,
71                ]
72                .join("");
73
74                // We want to write everything in as few write() syscalls as possible. Note that
75                // if we were to use eprint! with a format string like "{}{}", eprint! would
76                // always make at least two write() syscalls, which we wouldn't want.
77                std::io::stderr().write_all(to_print.as_bytes()).unwrap();
78                let _ = std::io::stderr().flush();
79            }
80            std::thread::sleep(redraw_interval);
81        }
82
83        let to_print = format!(
84            "{save_cursor}{last_line}{clear}{restore_scroll_region}{restore_cursor}",
85            save_cursor = SAVE_CURSOR,
86            last_line = LAST_LINE,
87            clear = CLEAR,
88            restore_scroll_region = RESTORE_SCROLL_REGION,
89            restore_cursor = RESTORE_CURSOR,
90        );
91
92        std::io::stderr().write_all(to_print.as_bytes()).unwrap();
93        let _ = std::io::stderr().flush();
94    }
95
96    /// Stop and remove the status bar.
97    pub fn stop(self) {
98        // will be stopped in the drop handler
99    }
100
101    pub fn status(&self) -> &Arc<Status<T>> {
102        &self.state
103    }
104}
105
106impl<T: 'static + StatusBarState> std::ops::Drop for StatusBar<T> {
107    fn drop(&mut self) {
108        self.stop_flag
109            .swap(true, std::sync::atomic::Ordering::Relaxed);
110        if let Some(handle) = self.thread.take() {
111            if let Err(e) = handle.join() {
112                log::warn!("Progress bar thread did not exit cleanly: {:?}", e);
113            }
114        }
115    }
116}
117
118pub struct StatusPrinter<T: 'static + StatusBarState> {
119    state: Arc<Status<T>>,
120    stop_sender: Option<std::sync::mpsc::Sender<()>>,
121    thread: Option<std::thread::JoinHandle<()>>,
122}
123
124impl<T: 'static + StatusBarState> StatusPrinter<T> {
125    /// Create and start printing the status.
126    pub fn new(state: T) -> Self {
127        let state = Arc::new(Status::new(state));
128        let (stop_sender, stop_receiver) = std::sync::mpsc::channel();
129
130        Self {
131            state: Arc::clone(&state),
132            stop_sender: Some(stop_sender),
133            thread: Some(std::thread::spawn(move || {
134                Self::print_loop(state, stop_receiver);
135            })),
136        }
137    }
138
139    fn print_loop(state: Arc<Status<T>>, stop_receiver: std::sync::mpsc::Receiver<()>) {
140        let print_interval = Duration::from_secs(60);
141
142        loop {
143            match stop_receiver.recv_timeout(print_interval) {
144                // the sender disconnects to signal that we should stop
145                Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => break,
146                Err(std::sync::mpsc::RecvTimeoutError::Timeout) => {}
147                Ok(()) => unreachable!(),
148            }
149
150            // We want to write everything in as few write() syscalls as possible. Note that
151            // if we were to use eprint! with a format string like "{}{}", eprint! would
152            // always make at least two write() syscalls, which we wouldn't want.
153            let to_write = format!("Progress: {}\n", *state.inner.read().unwrap());
154            std::io::stderr().write_all(to_write.as_bytes()).unwrap();
155            let _ = std::io::stderr().flush();
156        }
157    }
158
159    /// Stop printing the status.
160    pub fn stop(self) {
161        // will be stopped in the drop handler
162    }
163
164    pub fn status(&self) -> &Arc<Status<T>> {
165        &self.state
166    }
167}
168
169impl<T: 'static + StatusBarState> std::ops::Drop for StatusPrinter<T> {
170    fn drop(&mut self) {
171        // drop the sender to disconnect it
172        self.stop_sender.take();
173        if let Some(handle) = self.thread.take() {
174            if let Err(e) = handle.join() {
175                log::warn!("Progress thread did not exit cleanly: {:?}", e);
176            }
177        }
178    }
179}
180
/// Shared state displayed by the status bar or status printer.
#[derive(Debug)]
pub struct Status<T> {
    // The lock is kept private so that the locking strategy stays an implementation detail; it
    // could be swapped for a lock with cheaper writes later without affecting callers.
    inner: RwLock<T>,
}

impl<T> Status<T> {
    /// Wrap `value` as the initial state.
    fn new(value: T) -> Self {
        let inner = RwLock::new(value);
        Self { inner }
    }

    /// Modify the state in place by running `f` on it while holding the write lock. The new
    /// state becomes visible to the user at the next redraw.
    ///
    /// # Panics
    ///
    /// Panics if the lock is poisoned.
    pub fn update(&self, f: impl FnOnce(&mut T)) {
        let mut guard = self.inner.write().unwrap();
        f(&mut *guard);
    }
}
202
203nix::ioctl_read_bad!(_tiocgwinsz, libc::TIOCGWINSZ, libc::winsize);
204
205fn tiocgwinsz() -> nix::Result<libc::winsize> {
206    let mut win_size: libc::winsize = unsafe { std::mem::zeroed() };
207    unsafe { _tiocgwinsz(0, &mut win_size)? };
208    Ok(win_size)
209}