neli/
utils.rs

//! A module containing utilities for working with constructs like
//! bitflags and other low-level operations.
//!
//! # Design decisions
//! Some of the less-documented aspects of interacting with netlink
//! are handled internally by the types so that the user does not
//! have to be aware of them.

use std::mem::size_of;

/// Word type backing [`NetlinkBitArray`]; each word stores
/// `BitArrayType::BITS` (32) bits.
type BitArrayType = u32;

/// A bit array meant to be compatible with the bit array
/// returned by the `NETLINK_LIST_MEMBERSHIPS` socket operation
/// on netlink sockets.
///
/// Storage is a vector of `u32` words. Individual bits are addressed by a
/// 1-based index (see [`NetlinkBitArray::is_set`] and [`NetlinkBitArray::set`]).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NetlinkBitArray(Vec<BitArrayType>);

18/// bittest/bitset instrinsics are not stable in Rust so this
19/// needs to be implemented this way.
20#[allow(clippy::len_without_is_empty)]
21impl NetlinkBitArray {
22    const BIT_SIZE: usize = BitArrayType::BITS as usize;
23
24    /// Create a new bit array.
25    ///
26    /// This method will round `bit_len` up to the nearest
27    /// multiple of [`size_of::<u32>()`][std::mem::size_of].
28    pub fn new(bit_len: usize) -> Self {
29        let round = Self::BIT_SIZE - 1;
30        NetlinkBitArray(vec![0; ((bit_len + round) & !round) / Self::BIT_SIZE])
31    }
32
33    /// Resize the underlying vector to have enough space for
34    /// the nearest multiple of [`size_of::<u32>()`][std::mem::size_of]
35    /// rounded up.
36    pub fn resize_bits(&mut self, bit_len: usize) {
37        let round = Self::BIT_SIZE - 1;
38        self.0
39            .resize(((bit_len + round) & !round) / Self::BIT_SIZE, 0);
40    }
41
42    /// Resize the underlying vector to have enough space for
43    /// the nearest multiple of [`size_of::<BitArrayType>()`][std::mem::size_of].
44    pub fn resize(&mut self, bytes: usize) {
45        let byte_round = size_of::<BitArrayType>() - 1;
46        self.0.resize(
47            ((bytes + byte_round) & !byte_round) / size_of::<BitArrayType>(),
48            0,
49        );
50    }
51
52    /// Returns true if the `n`th bit is set.
53    pub fn is_set(&self, n: usize) -> bool {
54        if n == 0 {
55            return false;
56        }
57        let n_1 = n - 1;
58        let bit_segment = self.0[n_1 / Self::BIT_SIZE];
59        let bit_shifted_n = 1 << (n_1 % Self::BIT_SIZE);
60        bit_segment & bit_shifted_n == bit_shifted_n
61    }
62
63    /// Set the `n`th bit.
64    pub fn set(&mut self, n: usize) {
65        if n == 0 {
66            return;
67        }
68        let n_1 = n - 1;
69        let bit_segment = self.0[n_1 / Self::BIT_SIZE];
70        let bit_shifted_n = 1 << (n_1 % Self::BIT_SIZE);
71        self.0[n_1 / Self::BIT_SIZE] = bit_segment | bit_shifted_n;
72    }
73
74    /// Get a vector representation of all of the bit positions set
75    /// to 1 in this bit array.
76    ///
77    /// ## Example
78    /// ```
79    /// use neli::utils::NetlinkBitArray;
80    ///
81    /// let mut array = NetlinkBitArray::new(24);
82    /// array.set(4);
83    /// array.set(7);
84    /// array.set(23);
85    /// assert_eq!(array.to_vec(), vec![4, 7, 23]);
86    /// ```
87    pub fn to_vec(&self) -> Vec<u32> {
88        let mut bits = Vec::new();
89        for bit in 0..self.len_bits() {
90            let bit_shifted = 1 << (bit % Self::BIT_SIZE);
91            if bit_shifted & self.0[bit / Self::BIT_SIZE] == bit_shifted {
92                bits.push(bit as u32 + 1);
93            }
94        }
95        bits
96    }
97
98    /// Return the number of bits that can be contained in this bit
99    /// array.
100    pub fn len_bits(&self) -> usize {
101        self.0.len() * Self::BIT_SIZE
102    }
103
104    /// Return the length in bytes for this bit array.
105    pub fn len(&self) -> usize {
106        self.0.len() * size_of::<BitArrayType>()
107    }
108
109    pub(crate) fn as_mut_slice(&mut self) -> &mut [BitArrayType] {
110        self.0.as_mut_slice()
111    }
112}
113
#[cfg(test)]
mod test {
    use super::*;

    use crate::test::setup;

    #[test]
    fn test_bit_array() {
        setup();

        // A 7-bit request still occupies one whole 32-bit word.
        let mut small = NetlinkBitArray::new(7);
        assert_eq!(small.0.len(), 1);
        small.set(4);
        assert_eq!(small.0[0], 0b1000);
        assert!(small.is_set(4));
        for unset in 0..4 {
            assert!(!small.is_set(unset));
        }
        assert_eq!(small.len(), 4);
        assert_eq!(small.len_bits(), 32);

        // Bits 32 and 33 straddle the boundary between the first two words.
        let mut straddle = NetlinkBitArray::new(33);
        straddle.set(32);
        straddle.set(33);
        assert_eq!(straddle.0[0], 1 << 31);
        assert_eq!(straddle.0[1], 1);
        assert!(straddle.is_set(32));
        assert!(straddle.is_set(33));

        // Resizing by bit count rounds to whole words in both directions.
        let mut by_bits = NetlinkBitArray::new(32);
        assert_eq!(by_bits.len(), 4);
        by_bits.resize_bits(33);
        assert_eq!(by_bits.len(), 8);
        by_bits.resize_bits(1);
        assert_eq!(by_bits.len(), 4);

        // Resizing by byte count rounds to whole words as well.
        let mut by_bytes = NetlinkBitArray::new(33);
        assert_eq!(by_bytes.len(), 8);
        for &(requested, expected) in &[(1, 4), (9, 12)] {
            by_bytes.resize(requested);
            assert_eq!(by_bytes.len(), expected);
        }

        // Bit 4 of each of three words maps to positions 4, 36 and 68.
        let raw = NetlinkBitArray(vec![8, 8, 8]);
        assert_eq!(raw.to_vec(), vec![4, 36, 68]);
    }
}