neli/
utils.rs

1//! A module containing utilities for working with constructs like
2//! bitflags and other low level operations.
3//!
4//! # Design decisions
5//! Some of the less documented aspects of interacting with netlink
6//! are handled internally in the types so that the user does not
7//! have to be aware of them.
8
9use std::mem::size_of;
10
11#[cfg(any(feature = "sync", feature = "async"))]
12use crate::consts::MAX_NL_LENGTH;
13use crate::err::MsgError;
14
/// Storage word type backing [`NetlinkBitArray`].
type BitArrayType = u32;

/// A bit array meant to be compatible with the bit array
/// returned by the `NETLINK_LIST_MEMBERSHIPS` socket operation
/// on netlink sockets.
///
/// Bit positions are 1-indexed: bit `n` is stored in word `(n - 1) / 32`
/// at bit offset `(n - 1) % 32`.
pub struct NetlinkBitArray(Vec<BitArrayType>);

/// bittest/bitset intrinsics are not stable in Rust so this
/// needs to be implemented this way.
// `len` reports a length in bytes rather than an element count, so no
// `is_empty` companion is provided.
#[allow(clippy::len_without_is_empty)]
impl NetlinkBitArray {
    /// Number of bits held by one storage word.
    const BIT_SIZE: usize = BitArrayType::BITS as usize;

    /// Number of storage words needed for `len` units when one word holds
    /// `per_word` units, i.e. `ceil(len / per_word)`.
    ///
    /// Written as quotient plus remainder (instead of
    /// `(len + per_word - 1) / per_word`) so lengths close to `usize::MAX`
    /// cannot overflow.
    fn words_for(len: usize, per_word: usize) -> usize {
        len / per_word + usize::from(len % per_word != 0)
    }

    /// Map a 1-indexed bit position to its word index and the single-bit
    /// mask selecting it within that word. `n` must be non-zero.
    fn bit_location(n: usize) -> (usize, BitArrayType) {
        let offset = n - 1;
        (offset / Self::BIT_SIZE, 1 << (offset % Self::BIT_SIZE))
    }

    /// Create a new bit array with room for at least `bit_len` bits.
    ///
    /// `bit_len` is rounded up to the nearest multiple of [`u32::BITS`]
    /// (32 bits), so the array always consists of whole 32-bit words.
    /// All bits start unset.
    pub fn new(bit_len: usize) -> Self {
        NetlinkBitArray(vec![0; Self::words_for(bit_len, Self::BIT_SIZE)])
    }

    /// Resize the underlying vector to hold at least `bit_len` bits,
    /// rounded up to the nearest multiple of [`u32::BITS`] (32 bits).
    ///
    /// Growing appends zeroed words; shrinking drops trailing words.
    pub fn resize_bits(&mut self, bit_len: usize) {
        self.0.resize(Self::words_for(bit_len, Self::BIT_SIZE), 0);
    }

    /// Resize the underlying vector to hold at least `bytes` bytes,
    /// rounded up to the nearest multiple of
    /// [`size_of::<BitArrayType>()`][std::mem::size_of] (4 bytes).
    ///
    /// Growing appends zeroed words; shrinking drops trailing words.
    pub fn resize(&mut self, bytes: usize) {
        self.0
            .resize(Self::words_for(bytes, size_of::<BitArrayType>()), 0);
    }

    /// Returns true if the `n`th bit is set. Not zero indexed.
    ///
    /// Bit `0` and any bit beyond [`len_bits`](Self::len_bits) are
    /// reported as unset rather than causing a panic.
    pub fn is_set(&self, n: usize) -> bool {
        if n == 0 {
            return false;
        }
        let (word, mask) = Self::bit_location(n);
        matches!(self.0.get(word), Some(bits) if bits & mask != 0)
    }

    /// Set the `n`th bit. Not zero indexed; `n == 0` is ignored.
    ///
    /// # Panics
    /// Panics if `n` is greater than [`len_bits`](Self::len_bits). Grow
    /// the array with [`resize_bits`](Self::resize_bits) first.
    pub fn set(&mut self, n: usize) {
        if n == 0 {
            return;
        }
        let (word, mask) = Self::bit_location(n);
        self.0[word] |= mask;
    }

    /// Get a vector representation of all of the bit positions set
    /// to 1 in this bit array, in ascending order (1-indexed).
    ///
    /// ## Example
    /// ```
    /// use neli::utils::NetlinkBitArray;
    ///
    /// let mut array = NetlinkBitArray::new(24);
    /// array.set(4);
    /// array.set(7);
    /// array.set(23);
    /// assert_eq!(array.to_vec(), vec![4, 7, 23]);
    /// ```
    pub fn to_vec(&self) -> Vec<u32> {
        (1..=self.len_bits())
            .filter(|&n| self.is_set(n))
            // Group numbers are `u32`; positions above `u32::MAX` would only
            // exist in a deliberately oversized array and would truncate.
            .map(|n| n as u32)
            .collect()
    }

    /// Return the number of bits that can be contained in this bit
    /// array.
    pub fn len_bits(&self) -> usize {
        self.0.len() * Self::BIT_SIZE
    }

    /// Return the length in bytes for this bit array.
    pub fn len(&self) -> usize {
        self.0.len() * size_of::<BitArrayType>()
    }

    /// Mutable view of the underlying storage words (crate-internal).
    pub(crate) fn as_mut_slice(&mut self) -> &mut [BitArrayType] {
        self.0.as_mut_slice()
    }
}
127
128fn slice_to_mask(groups: &[u32]) -> Result<u32, MsgError> {
129    groups.iter().try_fold(0, |mask, next| {
130        if *next == 0 {
131            Ok(mask)
132        } else if next - 1 > 31 {
133            Err(MsgError::new(format!(
134                "Group {next} cannot be represented with a bit width of 32"
135            )))
136        } else {
137            Ok(mask | (1 << (*next - 1)))
138        }
139    })
140}
141
/// Expand a 32-bit group mask into the ascending list of 1-indexed group
/// numbers whose bits are set. Bit `i` (zero-based) is group `i + 1`.
fn mask_to_vec(mask: u32) -> Vec<u32> {
    // Inclusive bound so that bit 31 (group 32) is examined too, matching
    // `slice_to_mask`, which accepts groups 1 through 32.
    (1..=u32::BITS)
        .filter(|group| mask & (1 << (group - 1)) != 0)
        .collect()
}
147
148/// Struct implementing handling of groups both as numerical values and as
149/// bitmasks.
150pub struct Groups(Vec<u32>);
151
152impl Groups {
153    /// Create an empty set of netlink multicast groups
154    pub fn empty() -> Self {
155        Groups(vec![])
156    }
157
158    /// Create a new set of groups with a bitmask. Each bit represents a group.
159    pub fn new_bitmask(mask: u32) -> Self {
160        Groups(mask_to_vec(mask))
161    }
162
163    /// Add a new bitmask to the existing group set. Each bit represents a group.
164    pub fn add_bitmask(&mut self, mask: u32) {
165        for group in mask_to_vec(mask) {
166            if !self.0.contains(&group) {
167                self.0.push(group);
168            }
169        }
170    }
171
172    /// Remove a bitmask from the existing group set. Each bit represents a group
173    /// and each bit set to 1 will be removed.
174    pub fn remove_bitmask(&mut self, mask: u32) {
175        let remove_items = mask_to_vec(mask);
176        self.0 = self
177            .0
178            .drain(..)
179            .filter(|g| !remove_items.contains(g))
180            .collect::<Vec<_>>();
181    }
182
183    /// Create a new set of groups from a list of numerical groups values. This differs
184    /// from the bitmask representation where the value 3 represents group 3 in this
185    /// format as opposed to 0x4 in the bitmask format.
186    pub fn new_groups(groups: &[u32]) -> Self {
187        let mut vec = groups.to_owned();
188        vec.retain(|g| g != &0);
189        Groups(vec)
190    }
191
192    /// Add a list of numerical groups values to the set of groups. This differs
193    /// from the bitmask representation where the value 3 represents group 3 in this
194    /// format as opposed to 0x4 in the bitmask format.
195    pub fn add_groups(&mut self, groups: &[u32]) {
196        for group in groups {
197            if *group != 0 && !self.0.contains(group) {
198                self.0.push(*group)
199            }
200        }
201    }
202
203    /// Remove a list of numerical groups values from the set of groups. This differs
204    /// from the bitmask representation where the value 3 represents group 3 in this
205    /// format as opposed to 0x4 in the bitmask format.
206    pub fn remove_groups(&mut self, groups: &[u32]) {
207        self.0.retain(|g| !groups.contains(g));
208    }
209
210    /// Return the set of groups as a bitmask. The representation of a bitmask is u32.
211    pub fn as_bitmask(&self) -> Result<u32, MsgError> {
212        slice_to_mask(&self.0)
213    }
214
215    /// Return the set of groups as a vector of group values.
216    pub fn as_groups(&self) -> Vec<u32> {
217        self.0.clone()
218    }
219
220    /// Return the set of groups as a vector of group values.
221    pub fn into_groups(self) -> Vec<u32> {
222        self.0
223    }
224
225    /// Returns true if no group is set.
226    pub fn is_empty(&self) -> bool {
227        self.0.is_empty()
228    }
229}
230
/// Synchronous (blocking) utils.
#[cfg(feature = "sync")]
pub mod synchronous {
    use super::*;

    use std::{
        mem::take,
        ops::{Deref, DerefMut},
    };

    use log::trace;
    use parking_lot::{Condvar, Mutex};

    /// Length every pooled buffer is (re)sized to: the compile-time
    /// `NELI_AUTO_BUFFER_LEN` environment variable when it parses as a
    /// `usize`, otherwise [`MAX_NL_LENGTH`].
    fn buffer_len() -> usize {
        option_env!("NELI_AUTO_BUFFER_LEN")
            .and_then(|s| s.parse::<usize>().ok())
            .unwrap_or(MAX_NL_LENGTH)
    }

    /// Type containing information pertaining to the semaphore tracking.
    struct SemInfo {
        /// Maximum number of buffers that may be checked out at once.
        max: u64,
        /// Number of buffers currently checked out.
        count: u64,
    }

    /// Guard indicating that a buffer has been acquired and the semaphore has been
    /// incremented.
    ///
    /// Dereferences to the borrowed `Vec<u8>`. Dropping the guard returns the
    /// buffer, restored to its full length, to the pool and wakes one thread
    /// waiting in [`BufferPool::acquire`].
    pub struct BufferPoolGuard<'a>(&'a BufferPool, Vec<u8>);

    impl Deref for BufferPoolGuard<'_> {
        type Target = Vec<u8>;

        fn deref(&self) -> &Self::Target {
            &self.1
        }
    }

    impl DerefMut for BufferPoolGuard<'_> {
        fn deref_mut(&mut self) -> &mut Self::Target {
            &mut self.1
        }
    }

    impl AsRef<[u8]> for BufferPoolGuard<'_> {
        fn as_ref(&self) -> &[u8] {
            self.1.as_ref()
        }
    }

    impl AsMut<[u8]> for BufferPoolGuard<'_> {
        fn as_mut(&mut self) -> &mut [u8] {
            self.1.as_mut()
        }
    }

    impl BufferPoolGuard<'_> {
        /// Reduce the size of the internal buffer to the number of bytes read.
        ///
        /// # Panics
        /// Panics if `bytes_read` exceeds the current buffer length.
        pub fn reduce_size(&mut self, bytes_read: usize) {
            assert!(bytes_read <= self.1.len());
            self.1.truncate(bytes_read);
        }

        /// Reset the buffer to the original size.
        pub fn reset(&mut self) {
            self.1.resize(buffer_len(), 0);
        }
    }

    impl Drop for BufferPoolGuard<'_> {
        fn drop(&mut self) {
            let mut vec = take(&mut self.1);
            // Restore the full length before taking any lock, so a possible
            // reallocation never happens while other threads wait on the pool.
            vec.resize(buffer_len(), 0);
            {
                // Lock order (sem_info, then pool) matches `acquire`.
                let mut sem_info = self.0.sem_info.lock();
                let mut pool = self.0.pool.lock();
                sem_info.count -= 1;
                pool.push(vec);
                trace!(
                    "Semaphore released; current count is {}, available is {}",
                    sem_info.count,
                    sem_info.max - sem_info.count
                );
            }
            // Notify only after both locks have been released.
            self.0.condvar.notify_one();
        }
    }

    /// A pool of buffers available for reading concurrent netlink messages without
    /// truncation.
    pub struct BufferPool {
        pool: Mutex<Vec<Vec<u8>>>,
        sem_info: Mutex<SemInfo>,
        condvar: Condvar,
    }

    impl Default for BufferPool {
        /// Build a pool of `NELI_MAX_PARALLEL_READ_OPS` buffers (compile-time
        /// environment variable, default 3), each `buffer_len()` bytes long.
        fn default() -> Self {
            let max_parallel = option_env!("NELI_MAX_PARALLEL_READ_OPS")
                .and_then(|s| s.parse::<u64>().ok())
                .unwrap_or(3);
            let buffer_size = buffer_len();

            BufferPool {
                pool: Mutex::new(
                    (0..max_parallel)
                        .map(|_| vec![0; buffer_size])
                        .collect::<Vec<_>>(),
                ),
                sem_info: Mutex::new(SemInfo {
                    count: 0,
                    max: max_parallel,
                }),
                condvar: Condvar::new(),
            }
        }
    }

    impl BufferPool {
        /// Acquire a buffer for use.
        ///
        /// This method is backed by a semaphore: it blocks until fewer than
        /// the configured maximum number of buffers are checked out.
        pub fn acquire(&self) -> BufferPoolGuard<'_> {
            let mut sem_info = self.sem_info.lock();
            self.condvar
                .wait_while(&mut sem_info, |sem_info| sem_info.count >= sem_info.max);
            let mut pool = self.pool.lock();
            sem_info.count += 1;
            trace!(
                "Semaphore acquired; current count is {}, available is {}",
                sem_info.count,
                sem_info.max - sem_info.count
            );
            // `count < max` held before the increment, and the pool starts with
            // `max` buffers, so at least one buffer is available here.
            let buffer = pool
                .pop()
                .expect("Checked that there is an available permit");
            BufferPoolGuard(self, buffer)
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        use std::{
            io::Write,
            thread::{scope, sleep},
            time::Duration,
        };

        use crate::test::setup;

        #[test]
        fn test_buffer_pool() {
            setup();

            let pool = BufferPool::default();
            scope(|s| {
                s.spawn(|| {
                    let mut guard = pool.acquire();
                    sleep(Duration::from_secs(2));
                    guard.as_mut_slice().write_all(&[4]).unwrap();
                    assert_eq!(Some(&4), guard.first());
                });
                s.spawn(|| {
                    let mut guard = pool.acquire();
                    sleep(Duration::from_secs(3));
                    guard.as_mut_slice().write_all(&[1]).unwrap();
                    assert_eq!(Some(&1), guard.first());
                });
                s.spawn(|| {
                    let mut guard = pool.acquire();
                    sleep(Duration::from_secs(3));
                    guard.as_mut_slice().write_all(&[1]).unwrap();
                    assert_eq!(Some(&1), guard.first());
                });
                s.spawn(|| {
                    sleep(Duration::from_secs(1));
                    let mut guard = pool.acquire();
                    guard.as_mut_slice().write_all(&[1]).unwrap();
                    assert_eq!(Some(&1), guard.first());
                });
            });
            let pool = pool.pool.lock();
            assert_eq!(pool.len(), 3);
            for buf in pool.iter() {
                assert_eq!(Some(&1), buf.first());
            }
        }
    }
}
429
/// Asynchronous utils.
#[cfg(feature = "async")]
pub mod asynchronous {
    use super::*;

    use std::{
        mem::take,
        ops::{Deref, DerefMut},
    };

    use log::trace;
    use parking_lot::Mutex;
    use tokio::sync::{Semaphore, SemaphorePermit};

    /// Length every pooled buffer is (re)sized to: the compile-time
    /// `NELI_AUTO_BUFFER_LEN` environment variable when it parses as a
    /// `usize`, otherwise [`MAX_NL_LENGTH`].
    fn buffer_len() -> usize {
        option_env!("NELI_AUTO_BUFFER_LEN")
            .and_then(|s| s.parse::<usize>().ok())
            .unwrap_or(MAX_NL_LENGTH)
    }

    /// Guard indicating that a buffer has been acquired and the semaphore has been
    /// incremented.
    ///
    /// Dereferences to the borrowed `Vec<u8>`. Because struct fields are dropped
    /// after `Drop::drop` runs, the held permit is released only after the
    /// buffer has been pushed back into the pool.
    // The permit field is never read; it is stored only so that it is dropped
    // together with the guard.
    #[allow(dead_code)]
    pub struct BufferPoolGuard<'a>(&'a BufferPool, SemaphorePermit<'a>, Vec<u8>);

    impl Deref for BufferPoolGuard<'_> {
        type Target = Vec<u8>;

        fn deref(&self) -> &Self::Target {
            &self.2
        }
    }

    impl DerefMut for BufferPoolGuard<'_> {
        fn deref_mut(&mut self) -> &mut Self::Target {
            &mut self.2
        }
    }

    impl AsRef<[u8]> for BufferPoolGuard<'_> {
        fn as_ref(&self) -> &[u8] {
            self.2.as_ref()
        }
    }

    impl AsMut<[u8]> for BufferPoolGuard<'_> {
        fn as_mut(&mut self) -> &mut [u8] {
            self.2.as_mut()
        }
    }

    impl BufferPoolGuard<'_> {
        /// Reduce the size of the internal buffer to the number of bytes read.
        ///
        /// # Panics
        /// Panics if `bytes_read` exceeds the current buffer length.
        pub fn reduce_size(&mut self, bytes_read: usize) {
            assert!(bytes_read <= self.2.len());
            self.2.truncate(bytes_read);
        }

        /// Reset the buffer to the original size.
        pub fn reset(&mut self) {
            self.2.resize(buffer_len(), 0);
        }
    }

    impl Drop for BufferPoolGuard<'_> {
        fn drop(&mut self) {
            let mut vec = take(&mut self.2);
            // Restore the full length before taking the pool lock.
            vec.resize(buffer_len(), 0);
            // The buffer must be back in the pool before the permit (field 1)
            // is released after this method returns, so that a task granted
            // that permit always finds a buffer in `acquire`.
            self.0.pool.lock().push(vec);
            // The permit is still held at this point, so count it as already
            // returned to report the post-release state.
            let available = self.0.semaphore.available_permits() + 1;
            trace!(
                "Semaphore released; current count is {}, available is {}",
                self.0.max.saturating_sub(available),
                available
            );
        }
    }

    /// A pool of buffers available for reading concurrent netlink messages without
    /// truncation.
    pub struct BufferPool {
        pool: Mutex<Vec<Vec<u8>>>,
        max: usize,
        semaphore: Semaphore,
    }

    impl Default for BufferPool {
        /// Build a pool of `NELI_MAX_PARALLEL_READ_OPS` buffers (compile-time
        /// environment variable, default 3), each `buffer_len()` bytes long,
        /// guarded by a semaphore with the same number of permits.
        fn default() -> Self {
            let max_parallel = option_env!("NELI_MAX_PARALLEL_READ_OPS")
                .and_then(|s| s.parse::<usize>().ok())
                .unwrap_or(3);
            let buffer_size = buffer_len();

            BufferPool {
                pool: Mutex::new(
                    (0..max_parallel)
                        .map(|_| vec![0; buffer_size])
                        .collect::<Vec<_>>(),
                ),
                max: max_parallel,
                semaphore: Semaphore::new(max_parallel),
            }
        }
    }

    impl BufferPool {
        /// Acquire a buffer for use.
        ///
        /// This method is backed by a semaphore: it waits until one of the
        /// configured number of permits is free.
        pub async fn acquire(&self) -> BufferPoolGuard<'_> {
            let permit = self
                .semaphore
                .acquire()
                .await
                .expect("Semaphore is never closed");
            // Guards push their buffer back before releasing their permit, so
            // holding a permit guarantees the pool is non-empty.
            let buffer = self
                .pool
                .lock()
                .pop()
                .expect("Checked that there is an available permit");
            trace!(
                "Semaphore acquired; current count is {}, available is {}",
                self.max - self.semaphore.available_permits(),
                self.semaphore.available_permits(),
            );
            BufferPoolGuard(self, permit, buffer)
        }
    }
}
569
#[cfg(test)]
mod test {
    use super::*;

    use crate::test::setup;

    /// Covers construction, setting and testing bits, resizing by bits and
    /// by bytes, and listing the set bit positions.
    #[test]
    fn test_bit_array() {
        setup();

        // A 7-bit request still occupies one whole 32-bit word.
        let mut small = NetlinkBitArray::new(7);
        assert_eq!(small.0.len(), 1);
        small.set(4);
        assert_eq!(small.0[0], 0b1000);
        assert!(small.is_set(4));
        for unset in 0..=3 {
            assert!(!small.is_set(unset));
        }
        assert_eq!(small.len(), 4);
        assert_eq!(small.len_bits(), 32);

        // Bits 32 and 33 sit on either side of the first word boundary.
        let mut straddling = NetlinkBitArray::new(33);
        straddling.set(32);
        straddling.set(33);
        assert_eq!(straddling.0[0], 0x8000_0000);
        assert_eq!(straddling.0[1], 1);
        assert!(straddling.is_set(32));
        assert!(straddling.is_set(33));

        // Resizing by a number of bits.
        let mut by_bits = NetlinkBitArray::new(32);
        assert_eq!(by_bits.len(), 4);
        by_bits.resize_bits(33);
        assert_eq!(by_bits.len(), 8);
        by_bits.resize_bits(1);
        assert_eq!(by_bits.len(), 4);

        // Resizing by a number of bytes.
        let mut by_bytes = NetlinkBitArray::new(33);
        assert_eq!(by_bytes.len(), 8);
        by_bytes.resize(1);
        assert_eq!(by_bytes.len(), 4);
        by_bytes.resize(9);
        assert_eq!(by_bytes.len(), 12);

        // The fourth bit of each of three words.
        let words = NetlinkBitArray(vec![8, 8, 8]);
        assert_eq!(words.to_vec(), vec![4, 36, 68]);
    }

    /// A group list made only of zeros yields an empty set and a zero mask.
    #[test]
    fn test_groups() {
        setup();

        let all_zero = [0, 0, 0, 0];
        assert_eq!(Groups::new_groups(&all_zero).as_bitmask().unwrap(), 0);
        assert!(Groups::new_groups(&all_zero).as_groups().is_empty());
    }
}