shadow_rs/utility/
macros.rs

/// Log a warning, and if this is a debug build then also panic.
///
/// Takes a format string followed by its arguments; the same tokens are handed to both
/// `log::warn!` and `panic!`.
///
/// The expansion is a single block expression, so the macro can be used in statement position as
/// well as anywhere an expression is expected (for example directly as a `match` arm).
///
/// Note that in debug builds the arguments are evaluated twice: once to build the log message and
/// once to build the panic message.
macro_rules! debug_panic {
    ($($x:tt)+) => {{
        // always leave a trace in the log, even in builds where we go on to panic
        log::warn!($($x)+);
        // only compiled in when debug assertions are enabled
        #[cfg(debug_assertions)]
        panic!($($x)+);
    }};
}
9
/// Emit a log message at level `lvl_once` the first time this call site is reached, and at level
/// `lvl_remaining` every time it is reached after that.
///
/// A log target is not supported. The string "(LOG_ONCE)" is prepended to every message so that
/// a reader knows later occurrences will not show up at `lvl_once`.
///
/// ```
/// # use log::Level;
/// # use shadow_rs::log_once_at_level;
/// log_once_at_level!(Level::Warn, Level::Debug, "Unexpected flag {}", 10);
/// ```
#[allow(unused_macros)]
#[macro_export]
macro_rules! log_once_at_level {
    ($lvl_once:expr, $lvl_remaining:expr, $str:literal $($x:tt)*) => {
        // skip the atomic operation entirely when neither level would produce any output
        if log::log_enabled!($lvl_once) || log::log_enabled!($lvl_remaining) {
            // one flag per expansion of this macro, i.e. per call site
            static HAS_LOGGED: std::sync::atomic::AtomicBool =
                std::sync::atomic::AtomicBool::new(false);

            // TODO: a plain `load()` might be cheaper in the common (already logged) case, but
            // we'd want performance numbers before changing it
            let is_first = HAS_LOGGED
                .compare_exchange(
                    false,
                    true,
                    std::sync::atomic::Ordering::Relaxed,
                    std::sync::atomic::Ordering::Relaxed,
                )
                .is_ok();

            // only the caller that flipped the flag from `false` to `true` logs at `lvl_once`
            if is_first {
                log::log!($lvl_once, "(LOG_ONCE) {}", format_args!($str $($x)*))
            } else {
                log::log!($lvl_remaining, "(LOG_ONCE) {}", format_args!($str $($x)*))
            }
        }
    };
}
44
/// Emit a log message at level `lvl_once` the first time this call site sees each distinct
/// value, and at level `lvl_remaining` whenever it sees a value that was already logged.
///
/// A log target is not supported. The string "(LOG_ONCE)" is prepended to every message so that
/// a reader knows later occurrences will not show up at `lvl_once`.
///
/// The fast-path (where the given value has already been logged) acquires a read-lock and looks
/// up the value in a hash table.
///
/// ```
/// # use log::Level;
/// # use shadow_rs::log_once_per_value_at_level;
/// # let unknown_flag: i32 = 0;
/// log_once_per_value_at_level!(unknown_flag, i32, Level::Warn, Level::Debug, "Unknown flag value {unknown_flag}");
/// ```
#[allow(unused_macros)]
#[macro_export]
macro_rules! log_once_per_value_at_level {
    ($value:expr, $t:ty, $lvl_once:expr, $lvl_remaining:expr, $str:literal $($x:tt)*) => {
        // skip the set lookup entirely when neither level would produce any output
        if log::log_enabled!($lvl_once) || log::log_enabled!($lvl_remaining) {
            // one set of already-seen values per expansion of this macro, i.e. per call site
            static LOGGED_SET: $crate::utility::once_set::OnceSet<$t> =
                $crate::utility::once_set::OnceSet::new();

            // the result of `insert` selects the "once" level (`true`) or the "remaining" level
            let newly_inserted = LOGGED_SET.insert($value);
            let level = match newly_inserted {
                true => $lvl_once,
                false => $lvl_remaining,
            };
            log::log!(level, "(LOG_ONCE) {}", format_args!($str $($x)*))
        }
    };
}
79
/// Shorthand for `log_once_at_level!` with `log::Level::Warn` as the first-time level and
/// `log::Level::Debug` as the level for every later message from the same line.
///
/// A log target is not supported. The string "(LOG_ONCE)" is prepended to the message to show
/// that later messages will not be logged at warn level.
///
/// ```ignore
/// warn_once_then_debug!("Unexpected flag {}", 10);
/// ```
#[allow(unused_macros)]
macro_rules! warn_once_then_debug {
    ($($fmt_args:tt)+) => {
        log_once_at_level!(log::Level::Warn, log::Level::Debug, $($fmt_args)+);
    };
}
93
/// Shorthand for `log_once_at_level!` with `log::Level::Warn` as the first-time level and
/// `log::Level::Trace` as the level for every later message from the same line.
///
/// A log target is not supported. The string "(LOG_ONCE)" is prepended to the message to show
/// that later messages will not be logged at warn level.
///
/// ```ignore
/// warn_once_then_trace!("Unexpected flag {}", 10);
/// ```
#[allow(unused_macros)]
macro_rules! warn_once_then_trace {
    ($($fmt_args:tt)+) => {
        log_once_at_level!(log::Level::Warn, log::Level::Trace, $($fmt_args)+);
    };
}
107
/// Implements logging functions that were generated by the `log_syscall` macro.
///
/// This is a field-less unit type: the `log_syscall` macro adds one associated function per
/// syscall to it through generated `impl` blocks.
#[derive(Debug)]
pub struct SyscallLogger;
110
/// Generates a logging function for one syscall as an associated function of `SyscallLogger`.
///
/// The macro is written so that it can be invoked from inside an `impl` block, ideally right
/// before the syscall function itself is defined. It takes the syscall name, the return type, and
/// the argument types. The exact parameters of the generated function are spelled out in the
/// macro definition.
///
/// The invocation:
///
/// ```ignore
/// log_syscall!(close, /* rv */ c_int, /* fd */ c_int);
/// ```
///
/// expands to something like the following (some extra boilerplate is left out):
///
/// ```ignore
/// impl SyscallLogger {
///     pub fn close(...) -> std::io::Result<()> { ... }
/// }
/// ```
///
/// The generated function can later be called as:
///
/// ```ignore
/// SyscallLogger::close(...)?;
/// ```
macro_rules! log_syscall {
    // Only a name and a return type were given: forward with an empty list of argument types.
    ($syscall:ident, $ret:ty $(,)?) => {
        log_syscall!($syscall, $ret,,);
    };
    // General form: build a per-syscall constant name and hand off to the internal form below.
    ($syscall:ident, $ret:ty, $($arg:ty),* $(,)?) => {
        paste::paste! {
            log_syscall!([< _syscall_logger_ $syscall >]; $syscall, $ret, $($arg),*);
        }
    };
    // Internal form: the identifier before the `;` names the constant that wraps the `impl`.
    ($anchor:ident; $syscall:ident, $ret:ty, $($arg:ty),*) => {
        // Wrapping the `impl` in a constant is a hack that lets us write
        // "impl SyscallLogger { ... }" while we are already inside an
        // "impl SyscallHandler { ... }" block. Apparently this may become a hard error (with no
        // way to opt-out with an `allow`) in the future:
        // https://github.com/rust-lang/rust/issues/120363
        #[doc(hidden)]
        #[allow(non_upper_case_globals)]
        #[allow(non_local_definitions)]
        const $anchor : () = {
            impl crate::utility::macros::SyscallLogger {
                pub fn $syscall(
                    writer: impl std::io::Write,
                    args: [shadow_shim_helper_rs::syscall_types::SyscallReg; 6],
                    rv: &crate::host::syscall::types::SyscallResult,
                    fmt: crate::host::syscall::formatter::FmtOptions,
                    tid: crate::host::thread::ThreadId,
                    mem: &crate::host::memory_manager::MemoryManager,
                ) -> std::io::Result<()>
                {
                    // formatter values for the arguments and for the return value
                    let formatted_args = <crate::host::syscall::formatter::SyscallArgsFmt::<$($arg),*>>::new(args, fmt, mem);
                    let formatted_rv = crate::host::syscall::formatter::SyscallResultFmt::<$ret>::new(&rv, args, fmt, mem);

                    // the time is fetched only after both formatter values have been built
                    let now = crate::host::syscall::handler::Worker::current_time().unwrap();

                    crate::host::syscall::formatter::write_syscall(
                        writer,
                        &now,
                        tid,
                        std::stringify!($syscall),
                        formatted_args,
                        formatted_rv,
                    )
                }
            }
        };
    };
}
177
/// Projects a mutable slice of (possibly uninitialized) bytes onto one, two, or three named
/// fields of `$type`, producing a mutable `MaybeUninit` reference to each requested field.
///
/// With a single bare field name the result is an `Option` of one reference; with a
/// parenthesized list of fields the result is an `Option` of a tuple of references.
///
/// Returns `None` if any field is not aligned, or if the bytes slice is too small to contain all
/// fields.
macro_rules! field_project {
    // one bare field: project as a one-element tuple and then unwrap the tuple
    ($bytes:expr, $type:ty, $field1:ident) => {
        field_project!($bytes, $type, ($field1,)).map(|fields| fields.0)
    };
    // the tuple forms pair each field with a generic parameter name for the helper function
    ($bytes:expr, $type:ty, ($field1:ident,)) => {
        field_project!(@ $bytes, $type, ($field1: A))
    };
    ($bytes:expr, $type:ty, ($field1:ident, $field2:ident)) => {
        field_project!(@ $bytes, $type, ($field1: A), ($field2: B))
    };
    ($bytes:expr, $type:ty, ($field1:ident, $field2:ident, $field3:ident)) => {
        field_project!(@ $bytes, $type, ($field1: A), ($field2: B), ($field3: C))
    };
    (@ $bytes:expr, $type:ty, $(($field:ident: $generic:ident)),*) => {{
        // type-check the input up front; it has to be `MaybeUninit<u8>` and not plain `u8`,
        // because otherwise this macro could be used to write uninitialized padding bytes into a
        // `u8` slice
        let bytes: &mut [std::mem::MaybeUninit<u8>] = $bytes;

        // a pointer that is only ever used to name the fields of `$type`
        const UNINIT: *const $type = std::mem::MaybeUninit::uninit().as_ptr();

        // the size of whatever type the given pointer points to
        const fn size_of_pointee<T>(_x: *const T) -> usize {
            std::mem::size_of::<T>()
        }

        // This helper function exists to:
        // - require that every field type is `Pod`
        // - tie the lifetime of the returned references to the lifetime of `bytes` (so that we
        //   never hand out a 'static lifetime by accident)
        // - give the returned references the correct field types, which afaik can only be
        //   obtained through the `addr_of` macro
        fn field_project<$( $generic: shadow_pod::Pod ),*>(
            bytes: &mut [std::mem::MaybeUninit<u8>],
            _for_type_coercion: ($( *const $generic ),*,)
        ) -> Option<($( &mut std::mem::MaybeUninit<$generic> ),*,)> {
            // the byte range that each requested field occupies within `$type`
            const RANGES: &[std::ops::Range<usize>] = &[ $( {
                const OFFSET: usize = std::mem::offset_of!($type, $field);
                const SIZE: usize = size_of_pointee(unsafe { std::ptr::addr_of!((*UNINIT).$field) });
                OFFSET..(OFFSET+SIZE)
            } ),* ];

            // compile-time check: no two byte ranges may overlap
            const {
                let mut a = 0;
                while a < RANGES.len() {
                    let mut b = a + 1;
                    while b < RANGES.len() {
                        assert!(
                            RANGES[a].start >= RANGES[b].end || RANGES[b].start >= RANGES[a].end,
                            "Byte ranges overlap"
                        );
                        b += 1;
                    }
                    a += 1;
                }
            }

            // compile-time check: no two byte ranges may begin at the same offset (we don't want
            // to hand out two mutable references to the same ZST)
            const {
                let mut a = 0;
                while a < RANGES.len() {
                    let mut b = a + 1;
                    while b < RANGES.len() {
                        if RANGES[a].start == RANGES[b].start {
                            panic!("Byte ranges overlap (ZST)");
                        }
                        b += 1;
                    }
                    a += 1;
                }
            }

            // the largest end offset over all of the byte ranges
            const RANGE_MAX: usize = {
                let mut largest_end = 0;
                let mut idx = 0;
                while idx < RANGES.len() {
                    if RANGES[idx].end > largest_end {
                        largest_end = RANGES[idx].end;
                    }
                    idx += 1;
                }
                largest_end
            };

            // every requested field must lie entirely inside `bytes`
            if bytes.len() < RANGE_MAX {
                return None;
            }

            let base = bytes.as_mut_ptr();

            // build the tuple of references, one per requested field
            Some(( $( {
                // NOTE: the original 'bytes' slice must not be accessed anywhere in this block,
                // otherwise it causes stacked borrows issues
                const OFFSET: usize = std::mem::offset_of!($type, $field);

                // SAFETY: the length check above established that the field offset is within the
                // bounds of the bytes
                let field_ptr =
                    unsafe { base.add(OFFSET) }.cast::<std::mem::MaybeUninit<$generic>>();
                if !field_ptr.is_aligned() {
                    return None;
                }
                // SAFETY:
                // - "The pointer must be properly aligned." - checked just above
                // - "It must be 'dereferenceable' in the sense defined in the module
                //   documentation." - it points to valid memory within a single allocated object
                //   and is non-null
                // - "The pointer must point to an initialized instance of T." - the pointee type
                //   is a MaybeUninit
                // - "You must enforce Rust's aliasing rules, since the returned lifetime 'a is
                //   arbitrarily chosen and does not necessarily reflect the actual lifetime of the
                //   data. In particular, while this reference exists, the memory the pointer points
                //   to must not get accessed (read or written) through any other pointer." - the
                //   signature of the enclosing function gives the returned reference the correct
                //   lifetime
                unsafe { field_ptr.as_mut() }.unwrap()
            } ),*, ))
        }

        // the field types cannot be named directly, so instead we build values whose types
        // contain the field types and let rust's type inference pick the generic parameters of
        // the helper function from them
        let field_type_witnesses = ($( const { unsafe { std::ptr::addr_of!((*UNINIT).$field) } } ),*,);
        field_project(bytes, field_type_witnesses)
    }};
}
304
#[cfg(test)]
mod tests {
    // debug builds: the macro is expected to panic
    #[test]
    #[cfg(debug_assertions)]
    #[should_panic]
    fn debug_panic_macro() {
        debug_panic!("Hello {}", "World");
    }

    // release builds: the macro must *not* panic
    #[test]
    #[cfg(not(debug_assertions))]
    fn debug_panic_macro() {
        debug_panic!("Hello {}", "World");
    }

    #[test]
    fn log_once_at_level() {
        // no logger is installed here, so the log output can't be inspected (a custom logger
        // would probably make that possible); this test only makes sure the macro compiles
        for x in 0..10 {
            log_once_at_level!(log::Level::Warn, log::Level::Debug, "{x}");
        }

        log_once_at_level!(log::Level::Warn, log::Level::Debug, "A");
        log_once_at_level!(log::Level::Warn, log::Level::Debug, "A");

        // the log output we would expect is:
        // Warn: 0
        // Debug: 1
        // Debug: 2
        // ...
        // Warn: A
        // Warn: A
    }

    #[test]
    fn warn_once() {
        warn_once_then_trace!("A");
        warn_once_then_debug!("A");
    }

    #[test]
    fn field_project_1() {
        let mut hdr: libc::nlmsghdr = shadow_pod::zeroed();
        let hdr_bytes = unsafe { shadow_pod::as_u8_slice_mut(&mut hdr) };

        let type_field = field_project!(hdr_bytes, libc::nlmsghdr, nlmsg_type).unwrap();

        type_field.write(10);

        assert_eq!(hdr.nlmsg_type, 10);
    }

    #[test]
    fn field_project_2() {
        let mut hdr: libc::nlmsghdr = shadow_pod::zeroed();
        let hdr_bytes = unsafe { shadow_pod::as_u8_slice_mut(&mut hdr) };

        let (type_field, flags_field) =
            field_project!(hdr_bytes, libc::nlmsghdr, (nlmsg_type, nlmsg_flags)).unwrap();

        type_field.write(10);
        flags_field.write(20);

        // write again in the opposite order: the order in which the fields are accessed must not
        // matter (no stacked borrows miri errors)
        flags_field.write(40);
        type_field.write(30);

        assert_eq!(hdr.nlmsg_type, 30);
        assert_eq!(hdr.nlmsg_flags, 40);
    }

    #[test]
    fn field_project_type_inference() {
        let mut hdr: libc::nlmsghdr = shadow_pod::zeroed();
        let hdr_bytes = unsafe { shadow_pod::as_u8_slice_mut(&mut hdr) };

        // check that field_project hands back a u16 reference (ideally there would also be a test
        // that uses an incorrect type and checks that the code fails to build, to make sure
        // rust's type inference isn't leading to incorrect code, but rust tests that check for a
        // compile failure aren't supported and the workarounds aren't very nice)
        let _type_field: &mut std::mem::MaybeUninit<u16> =
            field_project!(hdr_bytes, libc::nlmsghdr, nlmsg_type).unwrap();
    }

    #[test]
    fn field_project_range() {
        let mut hdr: libc::nlmsghdr = shadow_pod::zeroed();
        let hdr_bytes = unsafe { shadow_pod::as_u8_slice_mut(&mut hdr) };

        // layout reminder:
        // #[repr(C)]
        // pub struct nlmsghdr {
        //     pub nlmsg_len: u32,
        //     pub nlmsg_type: u16,
        //     ...
        assert!(field_project!(&mut hdr_bytes[..0], libc::nlmsghdr, nlmsg_type).is_none());
        assert!(field_project!(&mut hdr_bytes[..5], libc::nlmsghdr, nlmsg_type).is_none());
        assert!(field_project!(&mut hdr_bytes[..6], libc::nlmsghdr, nlmsg_type).is_some());
    }

    #[test]
    fn field_project_align() {
        let mut hdr: libc::nlmsghdr = shadow_pod::zeroed();
        let hdr_bytes = unsafe { shadow_pod::as_u8_slice_mut(&mut hdr) };

        // layout reminder:
        // #[repr(C)]
        // pub struct nlmsghdr {
        //     pub nlmsg_len: u32,
        //     pub nlmsg_type: u16,
        //     ...
        assert!(field_project!(&mut hdr_bytes[..], libc::nlmsghdr, nlmsg_type).is_some());
        assert!(field_project!(&mut hdr_bytes[1..], libc::nlmsghdr, nlmsg_type).is_none());
    }
}