// neli/socket/shared.rs

use std::{
    io,
    mem::{MaybeUninit, size_of, zeroed},
    os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd},
};

use libc::{c_int, c_void, sockaddr, sockaddr_nl};

#[cfg(feature = "async")]
use crate::socket::asynchronous;
#[cfg(feature = "sync")]
use crate::socket::synchronous;
use crate::{
    consts::socket::*,
    utils::{Groups, NetlinkBitArray},
};

/// Low level access to a netlink socket.
///
/// Owns the raw file descriptor returned by `socket()`: it is closed when the
/// value is dropped unless ownership is released through `IntoRawFd`.
pub struct NlSocket {
    fd: RawFd,
}

23impl NlSocket {
24    /// Wrapper around `socket()` syscall filling in the
25    /// netlink-specific information.
26    pub fn new(proto: NlFamily) -> Result<Self, io::Error> {
27        let fd = match unsafe {
28            libc::socket(
29                AddrFamily::Netlink.into(),
30                libc::SOCK_RAW | libc::SOCK_CLOEXEC,
31                proto.into(),
32            )
33        } {
34            i if i >= 0 => Ok(i),
35            _ => Err(io::Error::last_os_error()),
36        }?;
37        Ok(NlSocket { fd })
38    }
39
40    /// Equivalent of `socket` and `bind` calls.
41    pub fn connect(proto: NlFamily, pid: Option<u32>, groups: Groups) -> Result<Self, io::Error> {
42        let s = NlSocket::new(proto)?;
43        s.bind(pid, groups)?;
44        Ok(s)
45    }
46
47    /// Set underlying socket file descriptor to be blocking.
48    pub fn block(&self) -> Result<(), io::Error> {
49        match unsafe {
50            libc::fcntl(
51                self.fd,
52                libc::F_SETFL,
53                libc::fcntl(self.fd, libc::F_GETFL, 0) & !libc::O_NONBLOCK,
54            )
55        } {
56            i if i < 0 => Err(io::Error::last_os_error()),
57            _ => Ok(()),
58        }
59    }
60
61    /// Set underlying socket file descriptor to be non blocking.
62    pub fn nonblock(&self) -> Result<(), io::Error> {
63        match unsafe {
64            libc::fcntl(
65                self.fd,
66                libc::F_SETFL,
67                libc::fcntl(self.fd, libc::F_GETFL, 0) | libc::O_NONBLOCK,
68            )
69        } {
70            i if i < 0 => Err(io::Error::last_os_error()),
71            _ => Ok(()),
72        }
73    }
74
75    /// Determines if underlying file descriptor is blocking.
76    pub fn is_blocking(&self) -> Result<bool, io::Error> {
77        let is_blocking = match unsafe { libc::fcntl(self.fd, libc::F_GETFL, 0) } {
78            i if i >= 0 => i & libc::O_NONBLOCK == 0,
79            _ => return Err(io::Error::last_os_error()),
80        };
81        Ok(is_blocking)
82    }
83
84    /// Use this function to bind to a netlink ID and subscribe to
85    /// groups. See netlink(7) man pages for more information on
86    /// netlink IDs and groups.
87    pub fn bind(&self, pid: Option<u32>, groups: Groups) -> Result<(), io::Error> {
88        let mut nladdr = unsafe { zeroed::<libc::sockaddr_nl>() };
89        nladdr.nl_family = c_int::from(AddrFamily::Netlink) as u16;
90        nladdr.nl_pid = pid.unwrap_or(0);
91        match unsafe {
92            libc::bind(
93                self.fd,
94                &nladdr as *const _ as *const libc::sockaddr,
95                size_of::<libc::sockaddr_nl>() as u32,
96            )
97        } {
98            i if i >= 0 => (),
99            _ => return Err(io::Error::last_os_error()),
100        };
101        self.add_mcast_membership(groups)?;
102        Ok(())
103    }
104
105    /// Set the size of the receive buffer for the socket.
106    ///
107    /// This can be useful when communicating with a service that sends a high volume of
108    /// messages (especially multicast), and your application cannot process them fast enough,
109    /// leading to the kernel dropping messages. A larger buffer may help mitigate this.
110    ///
111    /// The value passed is a hint to the kernel to set the size of the receive buffer.
112    /// The kernel will double the value provided to account for bookkeeping overhead.
113    /// The doubled value is capped by the value in `/proc/sys/net/core/rmem_max`.
114    ///
115    /// The default value is `/proc/sys/net/core/rmem_default`
116    ///
117    /// See `socket(7)` documentation for `SO_RCVBUF` for more information.
118    pub fn set_recv_buffer_size(&self, size: usize) -> Result<(), io::Error> {
119        let size = size as c_int;
120        match unsafe {
121            libc::setsockopt(
122                self.fd,
123                libc::SOL_SOCKET,
124                libc::SO_RCVBUF,
125                &size as *const _ as *const c_void,
126                size_of::<c_int>() as libc::socklen_t,
127            )
128        } {
129            0 => Ok(()),
130            _ => Err(io::Error::last_os_error()),
131        }
132    }
133
134    /// Join multicast groups for a socket.
135    pub fn add_mcast_membership(&self, groups: Groups) -> Result<(), io::Error> {
136        for group in groups.as_groups() {
137            match unsafe {
138                libc::setsockopt(
139                    self.fd,
140                    libc::SOL_NETLINK,
141                    libc::NETLINK_ADD_MEMBERSHIP,
142                    &group as *const _ as *const libc::c_void,
143                    size_of::<u32>() as libc::socklen_t,
144                )
145            } {
146                0 => (),
147                _ => return Err(io::Error::last_os_error()),
148            }
149        }
150        Ok(())
151    }
152
153    /// Leave multicast groups for a socket.
154    pub fn drop_mcast_membership(&self, groups: Groups) -> Result<(), io::Error> {
155        for group in groups.as_groups() {
156            match unsafe {
157                libc::setsockopt(
158                    self.fd,
159                    libc::SOL_NETLINK,
160                    libc::NETLINK_DROP_MEMBERSHIP,
161                    &group as *const _ as *const libc::c_void,
162                    size_of::<u32>() as libc::socklen_t,
163                )
164            } {
165                0 => (),
166                _ => return Err(io::Error::last_os_error()),
167            }
168        }
169        Ok(())
170    }
171
172    /// List joined groups for a socket.
173    pub fn list_mcast_membership(&self) -> Result<NetlinkBitArray, io::Error> {
174        let mut bit_array = NetlinkBitArray::new(4);
175        let mut len: libc::socklen_t = bit_array.len() as libc::socklen_t;
176        if unsafe {
177            libc::getsockopt(
178                self.fd,
179                libc::SOL_NETLINK,
180                libc::NETLINK_LIST_MEMBERSHIPS,
181                bit_array.as_mut_slice() as *mut _ as *mut libc::c_void,
182                &mut len as *mut _ as *mut libc::socklen_t,
183            )
184        } != 0
185        {
186            return Err(io::Error::last_os_error());
187        }
188        if len > bit_array.len() as libc::socklen_t {
189            bit_array.resize(len as usize);
190            if unsafe {
191                libc::getsockopt(
192                    self.fd,
193                    libc::SOL_NETLINK,
194                    libc::NETLINK_LIST_MEMBERSHIPS,
195                    bit_array.as_mut_slice() as *mut _ as *mut libc::c_void,
196                    &mut len as *mut _ as *mut libc::socklen_t,
197                )
198            } != 0
199            {
200                return Err(io::Error::last_os_error());
201            }
202        }
203        Ok(bit_array)
204    }
205
206    /// Send message encoded as byte slice to the netlink ID
207    /// specified in the netlink header
208    /// [`Nlmsghdr`][crate::nl::Nlmsghdr]
209    pub fn send<B>(&self, buf: B, flags: Msg) -> Result<libc::size_t, io::Error>
210    where
211        B: AsRef<[u8]>,
212    {
213        match unsafe {
214            libc::send(
215                self.fd,
216                buf.as_ref() as *const _ as *const c_void,
217                buf.as_ref().len(),
218                flags.bits() as i32,
219            )
220        } {
221            i if i >= 0 => Ok(i as libc::size_t),
222            _ => Err(io::Error::last_os_error()),
223        }
224    }
225
226    /// Receive message encoded as byte slice from the netlink socket.
227    pub fn recv<B>(&self, mut buf: B, flags: Msg) -> Result<(libc::size_t, Groups), io::Error>
228    where
229        B: AsMut<[u8]>,
230    {
231        let mut addr = unsafe { std::mem::zeroed::<sockaddr_nl>() };
232        let mut size: u32 = size_of::<sockaddr_nl>().try_into().unwrap_or(0);
233        match unsafe {
234            libc::recvfrom(
235                self.fd,
236                buf.as_mut() as *mut _ as *mut c_void,
237                buf.as_mut().len(),
238                flags.bits() as i32,
239                &mut addr as *mut _ as *mut sockaddr,
240                &mut size,
241            )
242        } {
243            i if i >= 0 => Ok((i as libc::size_t, Groups::new_bitmask(addr.nl_groups))),
244            i if i == -libc::EWOULDBLOCK as isize => {
245                Err(io::Error::from(io::ErrorKind::WouldBlock))
246            }
247            _ => Err(io::Error::last_os_error()),
248        }
249    }
250
251    /// Get the PID for this socket.
252    pub fn pid(&self) -> Result<u32, io::Error> {
253        let mut sock_len = size_of::<libc::sockaddr_nl>() as u32;
254        let mut sock_addr: MaybeUninit<libc::sockaddr_nl> = MaybeUninit::uninit();
255        match unsafe {
256            libc::getsockname(
257                self.fd,
258                sock_addr.as_mut_ptr() as *mut _,
259                &mut sock_len as *mut _,
260            )
261        } {
262            i if i >= 0 => Ok(unsafe { sock_addr.assume_init() }.nl_pid),
263            _ => Err(io::Error::last_os_error()),
264        }
265    }
266
267    /// If [`true`] is passed in, enable extended ACKs for this socket. If [`false`]
268    /// is passed in, disable extended ACKs for this socket.
269    pub fn enable_ext_ack(&self, enable: bool) -> Result<(), io::Error> {
270        match unsafe {
271            libc::setsockopt(
272                self.fd,
273                libc::SOL_NETLINK,
274                libc::NETLINK_EXT_ACK,
275                &c_int::from(enable) as *const _ as *const libc::c_void,
276                size_of::<i32>() as libc::socklen_t,
277            )
278        } {
279            0 => Ok(()),
280            _ => Err(io::Error::last_os_error()),
281        }
282    }
283
284    /// Return [`true`] if an extended ACK is enabled for this socket.
285    pub fn get_ext_ack_enabled(&self) -> Result<bool, io::Error> {
286        let mut sock_len = size_of::<libc::c_int>() as libc::socklen_t;
287        let mut sock_val: MaybeUninit<libc::c_int> = MaybeUninit::uninit();
288        match unsafe {
289            libc::getsockopt(
290                self.fd,
291                libc::SOL_NETLINK,
292                libc::NETLINK_EXT_ACK,
293                &mut sock_val as *mut _ as *mut libc::c_void,
294                &mut sock_len as *mut _ as *mut libc::socklen_t,
295            )
296        } {
297            0 => Ok(unsafe { sock_val.assume_init() } != 0),
298            _ => Err(io::Error::last_os_error()),
299        }
300    }
301
302    /// If [`true`] is passed in, enable strict checking for this socket. If [`false`]
303    /// is passed in, disable strict checking for for this socket.
304    /// Only supported by `NlFamily::Route` sockets.
305    /// Requires Linux >= 4.20.
306    pub fn enable_strict_checking(&self, enable: bool) -> Result<(), io::Error> {
307        match unsafe {
308            libc::setsockopt(
309                self.fd,
310                libc::SOL_NETLINK,
311                libc::NETLINK_GET_STRICT_CHK,
312                &libc::c_int::from(enable) as *const _ as *const libc::c_void,
313                size_of::<libc::c_int>() as libc::socklen_t,
314            )
315        } {
316            0 => Ok(()),
317            _ => Err(io::Error::last_os_error()),
318        }
319    }
320
321    /// Return [`true`] if strict checking is enabled for this socket.
322    /// Only supported by `NlFamily::Route` sockets.
323    /// Requires Linux >= 4.20.
324    pub fn get_strict_checking_enabled(&self) -> Result<bool, io::Error> {
325        let mut sock_len = size_of::<libc::c_int>() as libc::socklen_t;
326        let mut sock_val: MaybeUninit<libc::c_int> = MaybeUninit::uninit();
327        match unsafe {
328            libc::getsockopt(
329                self.fd,
330                libc::SOL_NETLINK,
331                libc::NETLINK_GET_STRICT_CHK,
332                &mut sock_val as *mut _ as *mut libc::c_void,
333                &mut sock_len as *mut _ as *mut libc::socklen_t,
334            )
335        } {
336            0 => Ok(unsafe { sock_val.assume_init() } != 0),
337            _ => Err(io::Error::last_os_error()),
338        }
339    }
340}
341
#[cfg(feature = "sync")]
impl From<synchronous::NlSocketHandle> for NlSocket {
    /// Unwrap a synchronous handle, taking ownership of the socket it holds.
    fn from(handle: synchronous::NlSocketHandle) -> Self {
        let synchronous::NlSocketHandle { socket, .. } = handle;
        socket
    }
}

#[cfg(feature = "async")]
impl From<asynchronous::NlSocketHandle> for NlSocket {
    /// Unwrap an asynchronous handle, extracting the socket from its
    /// `socket` wrapper via `into_inner()`.
    fn from(handle: asynchronous::NlSocketHandle) -> Self {
        let wrapped = handle.socket;
        wrapped.into_inner()
    }
}

356impl AsRawFd for NlSocket {
357    fn as_raw_fd(&self) -> RawFd {
358        self.fd
359    }
360}
361
362impl IntoRawFd for NlSocket {
363    fn into_raw_fd(self) -> RawFd {
364        let fd = self.fd;
365        std::mem::forget(self);
366        fd
367    }
368}
369
370impl FromRawFd for NlSocket {
371    unsafe fn from_raw_fd(fd: RawFd) -> Self {
372        NlSocket { fd }
373    }
374}
375
376impl Drop for NlSocket {
377    /// Closes underlying file descriptor to avoid file descriptor
378    /// leaks.
379    fn drop(&mut self) {
380        unsafe {
381            libc::close(self.fd);
382        }
383    }
384}
385
#[cfg(test)]
mod test {
    use super::*;

    use crate::test::setup;

    /// A port ID requested at bind time must be reported back by `pid()`.
    #[test]
    fn real_test_pid() {
        setup();

        let requested = 5555;
        let sock = NlSocket::connect(NlFamily::Generic, Some(requested), Groups::empty()).unwrap();
        let bound = sock.pid().unwrap();
        assert_eq!(bound, requested);
    }

    /// Extended ACKs start out disabled and can be switched on.
    #[test]
    fn real_ext_ack() {
        setup();

        let sock = NlSocket::connect(NlFamily::Generic, None, Groups::empty()).unwrap();
        let before = sock.get_ext_ack_enabled().unwrap();
        assert!(!before);

        sock.enable_ext_ack(true).unwrap();
        let after = sock.get_ext_ack_enabled().unwrap();
        assert!(after);
    }

    /// Strict checking on a route socket starts out disabled and can be switched on.
    #[test]
    fn real_strict_checking() {
        setup();

        let sock = NlSocket::connect(NlFamily::Route, None, Groups::empty()).unwrap();
        let before = sock.get_strict_checking_enabled().unwrap();
        assert!(!before);

        sock.enable_strict_checking(true).unwrap();
        let after = sock.get_strict_checking_enabled().unwrap();
        assert!(after);
    }
}