// neli/socket/synchronous.rs

1use std::{
2    fmt::Debug,
3    io::Cursor,
4    os::unix::io::{AsRawFd, IntoRawFd, RawFd},
5};
6
7use log::trace;
8
9use crate::{
10    FromBytesWithInput, Size, ToBytes,
11    consts::{nl::*, socket::*},
12    err::SocketError,
13    iter::NlBufferIter,
14    nl::Nlmsghdr,
15    socket::shared::NlSocket,
16    types::NlBuffer,
17    utils::{
18        Groups, NetlinkBitArray,
19        synchronous::{BufferPool, BufferPoolGuard},
20    },
21};
22
23/// Higher level handle for socket operations.
24pub struct NlSocketHandle {
25    pub(super) socket: NlSocket,
26    pid: u32,
27    pool: BufferPool,
28}
29
30impl NlSocketHandle {
31    /// Equivalent of `socket` and `bind` calls.
32    pub fn connect(proto: NlFamily, pid: Option<u32>, groups: Groups) -> Result<Self, SocketError> {
33        let socket = NlSocket::connect(proto, pid, groups)?;
34        socket.block()?;
35        let pid = socket.pid()?;
36        Ok(NlSocketHandle {
37            socket,
38            pid,
39            pool: BufferPool::default(),
40        })
41    }
42
43    /// Join multicast groups for a socket.
44    pub fn add_mcast_membership(&self, groups: Groups) -> Result<(), SocketError> {
45        self.socket
46            .add_mcast_membership(groups)
47            .map_err(SocketError::from)
48    }
49
50    /// Leave multicast groups for a socket.
51    pub fn drop_mcast_membership(&self, groups: Groups) -> Result<(), SocketError> {
52        self.socket
53            .drop_mcast_membership(groups)
54            .map_err(SocketError::from)
55    }
56
57    /// List joined groups for a socket.
58    pub fn list_mcast_membership(&self) -> Result<NetlinkBitArray, SocketError> {
59        self.socket
60            .list_mcast_membership()
61            .map_err(SocketError::from)
62    }
63
64    /// Get the PID for the current socket.
65    pub fn pid(&self) -> u32 {
66        self.pid
67    }
68
69    /// Convenience function to send an [`Nlmsghdr`] struct
70    pub fn send<T, P>(&self, msg: &Nlmsghdr<T, P>) -> Result<(), SocketError>
71    where
72        T: NlType + Debug,
73        P: Size + ToBytes + Debug,
74    {
75        trace!("Message sent:\n{msg:?}");
76
77        let mut buffer = Cursor::new(vec![0; msg.padded_size()]);
78        msg.to_bytes(&mut buffer)?;
79        trace!("Buffer sent: {:?}", buffer.get_ref());
80        self.socket.send(buffer.get_ref(), Msg::empty())?;
81
82        Ok(())
83    }
84
85    /// Convenience function to read a stream of [`Nlmsghdr`]
86    /// structs one by one using an iterator.
87    ///
88    /// Returns [`None`] when the stream of messages has been completely processed in
89    /// the current buffer resulting from a single
90    /// [`NlSocket::recv`][crate::socket::NlSocket::recv] call.
91    ///
92    /// See [`NlBufferIter`] for more detailed information.
93    pub fn recv<T, P>(
94        &self,
95    ) -> Result<(NlBufferIter<T, P, BufferPoolGuard<'_>>, Groups), SocketError>
96    where
97        T: NlType + Debug,
98        P: Size + FromBytesWithInput<Input = usize> + Debug,
99    {
100        let mut buffer = self.pool.acquire();
101        let (mem_read, groups) = self.socket.recv(&mut buffer, Msg::empty())?;
102        buffer.reduce_size(mem_read);
103        trace!("Buffer received: {:?}", buffer.as_ref());
104        Ok((NlBufferIter::new(Cursor::new(buffer)), groups))
105    }
106
107    /// Parse all [`Nlmsghdr`] structs sent in
108    /// one network packet and return them all in a list.
109    ///
110    /// Failure to parse any packet will cause the entire operation
111    /// to fail. If an error is detected at the application level,
112    /// this method will discard any non-error
113    /// [`Nlmsghdr`] structs and only return the
114    /// error. For a more granular approach, use [`NlSocketHandle::recv`].
115    pub fn recv_all<T, P>(&self) -> Result<(NlBuffer<T, P>, Groups), SocketError>
116    where
117        T: NlType + Debug,
118        P: Size + FromBytesWithInput<Input = usize> + Debug,
119    {
120        let mut buffer = self.pool.acquire();
121        let (mem_read, groups) = self.socket.recv(&mut buffer, Msg::empty())?;
122        if mem_read == 0 {
123            return Ok((NlBuffer::new(), Groups::empty()));
124        }
125        buffer.reduce_size(mem_read);
126
127        let vec = NlBuffer::from_bytes_with_input(&mut Cursor::new(buffer), mem_read)?;
128
129        trace!("Messages received: {vec:?}");
130
131        Ok((vec, groups))
132    }
133
134    /// Set the size of the receive buffer for the socket.
135    ///
136    /// This can be useful when communicating with a service that sends a high volume of
137    /// messages (especially multicast), and your application cannot process them fast enough,
138    /// leading to the kernel dropping messages. A larger buffer may help mitigate this.
139    ///
140    /// The value passed is a hint to the kernel to set the size of the receive buffer.
141    /// The kernel will double the value provided to account for bookkeeping overhead.
142    /// The doubled value is capped by the value in `/proc/sys/net/core/rmem_max`.
143    ///
144    /// The default value is `/proc/sys/net/core/rmem_default`
145    ///
146    /// See `socket(7)` documentation for `SO_RCVBUF` for more information.
147    pub fn set_recv_buffer_size(&self, size: usize) -> Result<(), SocketError> {
148        self.socket
149            .set_recv_buffer_size(size)
150            .map_err(SocketError::from)
151    }
152
153    /// If [`true`] is passed in, enable extended ACKs for this socket. If [`false`]
154    /// is passed in, disable extended ACKs for this socket.
155    pub fn enable_ext_ack(&self, enable: bool) -> Result<(), SocketError> {
156        self.socket
157            .enable_ext_ack(enable)
158            .map_err(SocketError::from)
159    }
160
161    /// Return [`true`] if an extended ACK is enabled for this socket.
162    pub fn get_ext_ack_enabled(&self) -> Result<bool, SocketError> {
163        self.socket.get_ext_ack_enabled().map_err(SocketError::from)
164    }
165
166    /// If [`true`] is passed in, enable strict checking for this socket. If [`false`]
167    /// is passed in, disable strict checking for for this socket.
168    /// Only supported by `NlFamily::Route` sockets.
169    /// Requires Linux >= 4.20.
170    pub fn enable_strict_checking(&self, enable: bool) -> Result<(), SocketError> {
171        self.socket
172            .enable_strict_checking(enable)
173            .map_err(SocketError::from)
174    }
175
176    /// Return [`true`] if strict checking is enabled for this socket.
177    /// Only supported by `NlFamily::Route` sockets.
178    /// Requires Linux >= 4.20.
179    pub fn get_strict_checking_enabled(&self) -> Result<bool, SocketError> {
180        self.socket
181            .get_strict_checking_enabled()
182            .map_err(SocketError::from)
183    }
184}
185
186impl AsRawFd for NlSocketHandle {
187    fn as_raw_fd(&self) -> RawFd {
188        self.socket.as_raw_fd()
189    }
190}
191
192impl IntoRawFd for NlSocketHandle {
193    fn into_raw_fd(self) -> RawFd {
194        self.socket.into_raw_fd()
195    }
196}