1use std::{
2 io,
3 mem::{MaybeUninit, size_of, zeroed},
4 os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd},
5};
6
7use libc::{c_int, c_void, sockaddr, sockaddr_nl};
8
9#[cfg(feature = "async")]
10use crate::socket::asynchronous;
11#[cfg(feature = "sync")]
12use crate::socket::synchronous;
13use crate::{
14 consts::socket::*,
15 utils::{Groups, NetlinkBitArray},
16};
17
/// Owning wrapper around a raw netlink socket file descriptor.
///
/// The descriptor is closed when the value is dropped (see the `Drop` impl),
/// unless ownership is released with `IntoRawFd::into_raw_fd`.
#[derive(Debug)]
pub struct NlSocket {
    /// Raw descriptor returned by `socket(2)` or supplied via `FromRawFd`.
    fd: c_int,
}
22
23impl NlSocket {
24 pub fn new(proto: NlFamily) -> Result<Self, io::Error> {
27 let fd = match unsafe {
28 libc::socket(
29 AddrFamily::Netlink.into(),
30 libc::SOCK_RAW | libc::SOCK_CLOEXEC,
31 proto.into(),
32 )
33 } {
34 i if i >= 0 => Ok(i),
35 _ => Err(io::Error::last_os_error()),
36 }?;
37 Ok(NlSocket { fd })
38 }
39
40 pub fn connect(proto: NlFamily, pid: Option<u32>, groups: Groups) -> Result<Self, io::Error> {
42 let s = NlSocket::new(proto)?;
43 s.bind(pid, groups)?;
44 Ok(s)
45 }
46
47 pub fn block(&self) -> Result<(), io::Error> {
49 match unsafe {
50 libc::fcntl(
51 self.fd,
52 libc::F_SETFL,
53 libc::fcntl(self.fd, libc::F_GETFL, 0) & !libc::O_NONBLOCK,
54 )
55 } {
56 i if i < 0 => Err(io::Error::last_os_error()),
57 _ => Ok(()),
58 }
59 }
60
61 pub fn nonblock(&self) -> Result<(), io::Error> {
63 match unsafe {
64 libc::fcntl(
65 self.fd,
66 libc::F_SETFL,
67 libc::fcntl(self.fd, libc::F_GETFL, 0) | libc::O_NONBLOCK,
68 )
69 } {
70 i if i < 0 => Err(io::Error::last_os_error()),
71 _ => Ok(()),
72 }
73 }
74
75 pub fn is_blocking(&self) -> Result<bool, io::Error> {
77 let is_blocking = match unsafe { libc::fcntl(self.fd, libc::F_GETFL, 0) } {
78 i if i >= 0 => i & libc::O_NONBLOCK == 0,
79 _ => return Err(io::Error::last_os_error()),
80 };
81 Ok(is_blocking)
82 }
83
84 pub fn bind(&self, pid: Option<u32>, groups: Groups) -> Result<(), io::Error> {
88 let mut nladdr = unsafe { zeroed::<libc::sockaddr_nl>() };
89 nladdr.nl_family = c_int::from(AddrFamily::Netlink) as u16;
90 nladdr.nl_pid = pid.unwrap_or(0);
91 match unsafe {
92 libc::bind(
93 self.fd,
94 &nladdr as *const _ as *const libc::sockaddr,
95 size_of::<libc::sockaddr_nl>() as u32,
96 )
97 } {
98 i if i >= 0 => (),
99 _ => return Err(io::Error::last_os_error()),
100 };
101 self.add_mcast_membership(groups)?;
102 Ok(())
103 }
104
105 pub fn set_recv_buffer_size(&self, size: usize) -> Result<(), io::Error> {
119 let size = size as c_int;
120 match unsafe {
121 libc::setsockopt(
122 self.fd,
123 libc::SOL_SOCKET,
124 libc::SO_RCVBUF,
125 &size as *const _ as *const c_void,
126 size_of::<c_int>() as libc::socklen_t,
127 )
128 } {
129 0 => Ok(()),
130 _ => Err(io::Error::last_os_error()),
131 }
132 }
133
134 pub fn add_mcast_membership(&self, groups: Groups) -> Result<(), io::Error> {
136 for group in groups.as_groups() {
137 match unsafe {
138 libc::setsockopt(
139 self.fd,
140 libc::SOL_NETLINK,
141 libc::NETLINK_ADD_MEMBERSHIP,
142 &group as *const _ as *const libc::c_void,
143 size_of::<u32>() as libc::socklen_t,
144 )
145 } {
146 0 => (),
147 _ => return Err(io::Error::last_os_error()),
148 }
149 }
150 Ok(())
151 }
152
153 pub fn drop_mcast_membership(&self, groups: Groups) -> Result<(), io::Error> {
155 for group in groups.as_groups() {
156 match unsafe {
157 libc::setsockopt(
158 self.fd,
159 libc::SOL_NETLINK,
160 libc::NETLINK_DROP_MEMBERSHIP,
161 &group as *const _ as *const libc::c_void,
162 size_of::<u32>() as libc::socklen_t,
163 )
164 } {
165 0 => (),
166 _ => return Err(io::Error::last_os_error()),
167 }
168 }
169 Ok(())
170 }
171
172 pub fn list_mcast_membership(&self) -> Result<NetlinkBitArray, io::Error> {
174 let mut bit_array = NetlinkBitArray::new(4);
175 let mut len: libc::socklen_t = bit_array.len() as libc::socklen_t;
176 if unsafe {
177 libc::getsockopt(
178 self.fd,
179 libc::SOL_NETLINK,
180 libc::NETLINK_LIST_MEMBERSHIPS,
181 bit_array.as_mut_slice() as *mut _ as *mut libc::c_void,
182 &mut len as *mut _ as *mut libc::socklen_t,
183 )
184 } != 0
185 {
186 return Err(io::Error::last_os_error());
187 }
188 if len > bit_array.len() as libc::socklen_t {
189 bit_array.resize(len as usize);
190 if unsafe {
191 libc::getsockopt(
192 self.fd,
193 libc::SOL_NETLINK,
194 libc::NETLINK_LIST_MEMBERSHIPS,
195 bit_array.as_mut_slice() as *mut _ as *mut libc::c_void,
196 &mut len as *mut _ as *mut libc::socklen_t,
197 )
198 } != 0
199 {
200 return Err(io::Error::last_os_error());
201 }
202 }
203 Ok(bit_array)
204 }
205
206 pub fn send<B>(&self, buf: B, flags: Msg) -> Result<libc::size_t, io::Error>
210 where
211 B: AsRef<[u8]>,
212 {
213 match unsafe {
214 libc::send(
215 self.fd,
216 buf.as_ref() as *const _ as *const c_void,
217 buf.as_ref().len(),
218 flags.bits() as i32,
219 )
220 } {
221 i if i >= 0 => Ok(i as libc::size_t),
222 _ => Err(io::Error::last_os_error()),
223 }
224 }
225
226 pub fn recv<B>(&self, mut buf: B, flags: Msg) -> Result<(libc::size_t, Groups), io::Error>
228 where
229 B: AsMut<[u8]>,
230 {
231 let mut addr = unsafe { std::mem::zeroed::<sockaddr_nl>() };
232 let mut size: u32 = size_of::<sockaddr_nl>().try_into().unwrap_or(0);
233 match unsafe {
234 libc::recvfrom(
235 self.fd,
236 buf.as_mut() as *mut _ as *mut c_void,
237 buf.as_mut().len(),
238 flags.bits() as i32,
239 &mut addr as *mut _ as *mut sockaddr,
240 &mut size,
241 )
242 } {
243 i if i >= 0 => Ok((i as libc::size_t, Groups::new_bitmask(addr.nl_groups))),
244 i if i == -libc::EWOULDBLOCK as isize => {
245 Err(io::Error::from(io::ErrorKind::WouldBlock))
246 }
247 _ => Err(io::Error::last_os_error()),
248 }
249 }
250
251 pub fn pid(&self) -> Result<u32, io::Error> {
253 let mut sock_len = size_of::<libc::sockaddr_nl>() as u32;
254 let mut sock_addr: MaybeUninit<libc::sockaddr_nl> = MaybeUninit::uninit();
255 match unsafe {
256 libc::getsockname(
257 self.fd,
258 sock_addr.as_mut_ptr() as *mut _,
259 &mut sock_len as *mut _,
260 )
261 } {
262 i if i >= 0 => Ok(unsafe { sock_addr.assume_init() }.nl_pid),
263 _ => Err(io::Error::last_os_error()),
264 }
265 }
266
267 pub fn enable_ext_ack(&self, enable: bool) -> Result<(), io::Error> {
270 match unsafe {
271 libc::setsockopt(
272 self.fd,
273 libc::SOL_NETLINK,
274 libc::NETLINK_EXT_ACK,
275 &c_int::from(enable) as *const _ as *const libc::c_void,
276 size_of::<i32>() as libc::socklen_t,
277 )
278 } {
279 0 => Ok(()),
280 _ => Err(io::Error::last_os_error()),
281 }
282 }
283
284 pub fn get_ext_ack_enabled(&self) -> Result<bool, io::Error> {
286 let mut sock_len = size_of::<libc::c_int>() as libc::socklen_t;
287 let mut sock_val: MaybeUninit<libc::c_int> = MaybeUninit::uninit();
288 match unsafe {
289 libc::getsockopt(
290 self.fd,
291 libc::SOL_NETLINK,
292 libc::NETLINK_EXT_ACK,
293 &mut sock_val as *mut _ as *mut libc::c_void,
294 &mut sock_len as *mut _ as *mut libc::socklen_t,
295 )
296 } {
297 0 => Ok(unsafe { sock_val.assume_init() } != 0),
298 _ => Err(io::Error::last_os_error()),
299 }
300 }
301
302 pub fn enable_strict_checking(&self, enable: bool) -> Result<(), io::Error> {
307 match unsafe {
308 libc::setsockopt(
309 self.fd,
310 libc::SOL_NETLINK,
311 libc::NETLINK_GET_STRICT_CHK,
312 &libc::c_int::from(enable) as *const _ as *const libc::c_void,
313 size_of::<libc::c_int>() as libc::socklen_t,
314 )
315 } {
316 0 => Ok(()),
317 _ => Err(io::Error::last_os_error()),
318 }
319 }
320
321 pub fn get_strict_checking_enabled(&self) -> Result<bool, io::Error> {
325 let mut sock_len = size_of::<libc::c_int>() as libc::socklen_t;
326 let mut sock_val: MaybeUninit<libc::c_int> = MaybeUninit::uninit();
327 match unsafe {
328 libc::getsockopt(
329 self.fd,
330 libc::SOL_NETLINK,
331 libc::NETLINK_GET_STRICT_CHK,
332 &mut sock_val as *mut _ as *mut libc::c_void,
333 &mut sock_len as *mut _ as *mut libc::socklen_t,
334 )
335 } {
336 0 => Ok(unsafe { sock_val.assume_init() } != 0),
337 _ => Err(io::Error::last_os_error()),
338 }
339 }
340}
341
#[cfg(feature = "sync")]
impl From<synchronous::NlSocketHandle> for NlSocket {
    /// Consumes the synchronous handle and yields the socket it wraps.
    fn from(handle: synchronous::NlSocketHandle) -> Self {
        let synchronous::NlSocketHandle { socket, .. } = handle;
        socket
    }
}
348
#[cfg(feature = "async")]
impl From<asynchronous::NlSocketHandle> for NlSocket {
    /// Consumes the asynchronous handle and unwraps the socket held in its
    /// `socket` field via `into_inner`.
    fn from(handle: asynchronous::NlSocketHandle) -> Self {
        let asynchronous::NlSocketHandle { socket, .. } = handle;
        socket.into_inner()
    }
}
355
356impl AsRawFd for NlSocket {
357 fn as_raw_fd(&self) -> RawFd {
358 self.fd
359 }
360}
361
362impl IntoRawFd for NlSocket {
363 fn into_raw_fd(self) -> RawFd {
364 let fd = self.fd;
365 std::mem::forget(self);
366 fd
367 }
368}
369
370impl FromRawFd for NlSocket {
371 unsafe fn from_raw_fd(fd: RawFd) -> Self {
372 NlSocket { fd }
373 }
374}
375
376impl Drop for NlSocket {
377 fn drop(&mut self) {
380 unsafe {
381 libc::close(self.fd);
382 }
383 }
384}
385
#[cfg(test)]
mod test {
    use super::*;

    use crate::test::setup;

    /// Binding with an explicit port ID is reflected by `pid()`.
    #[test]
    fn real_test_pid() {
        setup();

        let s = NlSocket::connect(NlFamily::Generic, Some(5555), Groups::empty()).unwrap();
        assert_eq!(s.pid().unwrap(), 5555);
    }

    /// Extended ACKs start disabled and can be switched on.
    #[test]
    fn real_ext_ack() {
        setup();

        let s = NlSocket::connect(NlFamily::Generic, None, Groups::empty()).unwrap();
        assert!(!s.get_ext_ack_enabled().unwrap());
        s.enable_ext_ack(true).unwrap();
        assert!(s.get_ext_ack_enabled().unwrap());
    }

    /// Strict checking starts disabled and can be switched on.
    #[test]
    fn real_strict_checking() {
        setup();

        let s = NlSocket::connect(NlFamily::Route, None, Groups::empty()).unwrap();
        assert!(!s.get_strict_checking_enabled().unwrap());
        s.enable_strict_checking(true).unwrap();
        assert!(s.get_strict_checking_enabled().unwrap());
    }

    /// A new socket is blocking (no `SOCK_NONBLOCK` is requested), and
    /// `nonblock`/`block` toggle `O_NONBLOCK` as reported by `is_blocking`.
    #[test]
    fn real_blocking() {
        setup();

        let s = NlSocket::connect(NlFamily::Generic, None, Groups::empty()).unwrap();
        assert!(s.is_blocking().unwrap());
        s.nonblock().unwrap();
        assert!(!s.is_blocking().unwrap());
        s.block().unwrap();
        assert!(s.is_blocking().unwrap());
    }
}
419}