neli/
socket.rs

1//! This module provides code that glues all of the other modules
2//! together and allows message send and receive operations.
3//!
4//! ## Important methods
5//! * [`NlSocket::send`] and [`NlSocket::recv`] methods are meant to
6//! be the most low level calls. They essentially do what the C
7//! system calls `send` and `recv` do with very little abstraction.
8//! * [`NlSocketHandle::send`] and [`NlSocketHandle::recv`] methods
9//! are meant to provide an interface that is more idiomatic for
10//! the library.
11//! * [`NlSocketHandle::iter`] provides a loop based iteration
12//! through messages that are received in a stream over the socket.
13//!
14//! ## Features
15//! The `async` feature exposed by `cargo` allows the socket to use
16//! Rust's [tokio](https://tokio.rs) for async IO.
17//!
18//! ## Additional methods
19//!
//! There are methods for switching between blocking and non-blocking
//! mode, resolving generic netlink multicast group IDs, and other
//! convenience functions, so check whether your use case is already
//! supported. If it isn't, please open a GitHub issue and submit a
//! feature request.
24//!
25//! ## Design decisions
26//!
//! The receive buffer in the [`NlSocketHandle`] structure is allocated
//! on the heap. This is intentional, as a buffer that large could be a
//! problem on the stack.
30
31use std::{
32    fmt::Debug,
33    io::{self, Cursor},
34    mem::{size_of, zeroed, MaybeUninit},
35    os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd},
36};
37
38use libc::{self, c_int, c_void};
39use log::debug;
40
41use crate::{
42    consts::{genl::*, nl::*, socket::*, MAX_NL_LENGTH},
43    err::{NlError, SerError},
44    genl::{Genlmsghdr, Nlattr},
45    iter::{IterationBehavior, NlMessageIter},
46    nl::{NlPayload, Nlmsghdr},
47    parse::packet_length_u32,
48    types::{GenlBuffer, NlBuffer},
49    utils::NetlinkBitArray,
50    FromBytes, FromBytesWithInput, ToBytes,
51};
52
/// Low level access to a netlink socket.
///
/// Owns a raw socket file descriptor, which is closed when the value is
/// dropped.
pub struct NlSocket {
    fd: RawFd,
}
57
58impl NlSocket {
59    /// Wrapper around `socket()` syscall filling in the
60    /// netlink-specific information.
61    pub fn new(proto: NlFamily) -> Result<Self, io::Error> {
62        let fd =
63            match unsafe { libc::socket(AddrFamily::Netlink.into(), libc::SOCK_RAW, proto.into()) }
64            {
65                i if i >= 0 => Ok(i),
66                _ => Err(io::Error::last_os_error()),
67            }?;
68        Ok(NlSocket { fd })
69    }
70
71    /// Equivalent of `socket` and `bind` calls.
72    pub fn connect(proto: NlFamily, pid: Option<u32>, groups: &[u32]) -> Result<Self, io::Error> {
73        let s = NlSocket::new(proto)?;
74        s.bind(pid, groups)?;
75        Ok(s)
76    }
77
78    /// Set underlying socket file descriptor to be blocking.
79    pub fn block(&self) -> Result<(), io::Error> {
80        match unsafe {
81            libc::fcntl(
82                self.fd,
83                libc::F_SETFL,
84                libc::fcntl(self.fd, libc::F_GETFL, 0) & !libc::O_NONBLOCK,
85            )
86        } {
87            i if i < 0 => Err(io::Error::last_os_error()),
88            _ => Ok(()),
89        }
90    }
91
92    /// Set underlying socket file descriptor to be non blocking.
93    pub fn nonblock(&self) -> Result<(), io::Error> {
94        match unsafe {
95            libc::fcntl(
96                self.fd,
97                libc::F_SETFL,
98                libc::fcntl(self.fd, libc::F_GETFL, 0) | libc::O_NONBLOCK,
99            )
100        } {
101            i if i < 0 => Err(io::Error::last_os_error()),
102            _ => Ok(()),
103        }
104    }
105
106    /// Determines if underlying file descriptor is blocking.
107    pub fn is_blocking(&self) -> Result<bool, io::Error> {
108        let is_blocking = match unsafe { libc::fcntl(self.fd, libc::F_GETFL, 0) } {
109            i if i >= 0 => i & libc::O_NONBLOCK == 0,
110            _ => return Err(io::Error::last_os_error()),
111        };
112        Ok(is_blocking)
113    }
114
115    /// Use this function to bind to a netlink ID and subscribe to
116    /// groups. See netlink(7) man pages for more information on
117    /// netlink IDs and groups.
118    pub fn bind(&self, pid: Option<u32>, groups: &[u32]) -> Result<(), io::Error> {
119        let mut nladdr = unsafe { zeroed::<libc::sockaddr_nl>() };
120        nladdr.nl_family = libc::c_int::from(AddrFamily::Netlink) as u16;
121        nladdr.nl_pid = pid.unwrap_or(0);
122        nladdr.nl_groups = 0;
123        match unsafe {
124            libc::bind(
125                self.fd,
126                &nladdr as *const _ as *const libc::sockaddr,
127                size_of::<libc::sockaddr_nl>() as u32,
128            )
129        } {
130            i if i >= 0 => (),
131            _ => return Err(io::Error::last_os_error()),
132        };
133        if !groups.is_empty() {
134            self.add_mcast_membership(groups)?;
135        }
136        Ok(())
137    }
138
139    /// Join multicast groups for a socket.
140    pub fn add_mcast_membership(&self, groups: &[u32]) -> Result<(), io::Error> {
141        for group in groups {
142            match unsafe {
143                libc::setsockopt(
144                    self.fd,
145                    libc::SOL_NETLINK,
146                    libc::NETLINK_ADD_MEMBERSHIP,
147                    group as *const _ as *const libc::c_void,
148                    size_of::<u32>() as libc::socklen_t,
149                )
150            } {
151                i if i == 0 => (),
152                _ => return Err(io::Error::last_os_error()),
153            }
154        }
155        Ok(())
156    }
157
158    /// Leave multicast groups for a socket.
159    pub fn drop_mcast_membership(&self, groups: &[u32]) -> Result<(), io::Error> {
160        for group in groups {
161            match unsafe {
162                libc::setsockopt(
163                    self.fd,
164                    libc::SOL_NETLINK,
165                    libc::NETLINK_DROP_MEMBERSHIP,
166                    group as *const _ as *const libc::c_void,
167                    size_of::<u32>() as libc::socklen_t,
168                )
169            } {
170                i if i == 0 => (),
171                _ => return Err(io::Error::last_os_error()),
172            }
173        }
174        Ok(())
175    }
176
177    /// List joined groups for a socket.
178    pub fn list_mcast_membership(&self) -> Result<NetlinkBitArray, io::Error> {
179        let mut bit_array = NetlinkBitArray::new(4);
180        let mut len = bit_array.len();
181        if unsafe {
182            libc::getsockopt(
183                self.fd,
184                libc::SOL_NETLINK,
185                libc::NETLINK_LIST_MEMBERSHIPS,
186                bit_array.as_mut_slice() as *mut _ as *mut libc::c_void,
187                &mut len as *mut _ as *mut libc::socklen_t,
188            )
189        } != 0
190        {
191            return Err(io::Error::last_os_error());
192        }
193        if len > bit_array.len() {
194            bit_array.resize(len);
195            if unsafe {
196                libc::getsockopt(
197                    self.fd,
198                    libc::SOL_NETLINK,
199                    libc::NETLINK_LIST_MEMBERSHIPS,
200                    bit_array.as_mut_slice() as *mut _ as *mut libc::c_void,
201                    &mut len as *mut _ as *mut libc::socklen_t,
202                )
203            } != 0
204            {
205                return Err(io::Error::last_os_error());
206            }
207        }
208        Ok(bit_array)
209    }
210
211    /// Send message encoded as byte slice to the netlink ID
212    /// specified in the netlink header
213    /// [`Nlmsghdr`][crate::nl::Nlmsghdr]
214    pub fn send<B>(&self, buf: B, flags: i32) -> Result<libc::size_t, io::Error>
215    where
216        B: AsRef<[u8]>,
217    {
218        match unsafe {
219            libc::send(
220                self.fd,
221                buf.as_ref() as *const _ as *const c_void,
222                buf.as_ref().len(),
223                flags,
224            )
225        } {
226            i if i >= 0 => Ok(i as libc::size_t),
227            _ => Err(io::Error::last_os_error()),
228        }
229    }
230
231    /// Receive message encoded as byte slice from the netlink socket.
232    pub fn recv<B>(&self, mut buf: B, flags: i32) -> Result<libc::size_t, io::Error>
233    where
234        B: AsMut<[u8]>,
235    {
236        match unsafe {
237            libc::recv(
238                self.fd,
239                buf.as_mut() as *mut _ as *mut c_void,
240                buf.as_mut().len(),
241                flags,
242            )
243        } {
244            i if i >= 0 => Ok(i as libc::size_t),
245            _ => Err(io::Error::last_os_error()),
246        }
247    }
248
249    /// Get the PID for this socket.
250    pub fn pid(&self) -> Result<u32, io::Error> {
251        let mut sock_len = size_of::<libc::sockaddr_nl>() as u32;
252        let mut sock_addr: MaybeUninit<libc::sockaddr_nl> = MaybeUninit::uninit();
253        match unsafe {
254            libc::getsockname(
255                self.fd,
256                sock_addr.as_mut_ptr() as *mut _,
257                &mut sock_len as *mut _,
258            )
259        } {
260            i if i >= 0 => Ok(unsafe { sock_addr.assume_init() }.nl_pid),
261            _ => Err(io::Error::last_os_error()),
262        }
263    }
264}
265
266impl From<NlSocketHandle> for NlSocket {
267    fn from(s: NlSocketHandle) -> Self {
268        s.socket
269    }
270}
271
272impl AsRawFd for NlSocket {
273    fn as_raw_fd(&self) -> RawFd {
274        self.fd
275    }
276}
277
278impl IntoRawFd for NlSocket {
279    fn into_raw_fd(self) -> RawFd {
280        let fd = self.fd;
281        std::mem::forget(self);
282        fd
283    }
284}
285
286impl FromRawFd for NlSocket {
287    unsafe fn from_raw_fd(fd: RawFd) -> Self {
288        NlSocket { fd }
289    }
290}
291
292impl Drop for NlSocket {
293    /// Closes underlying file descriptor to avoid file descriptor
294    /// leaks.
295    fn drop(&mut self) {
296        unsafe {
297            libc::close(self.fd);
298        }
299    }
300}
301
302/// Higher level handle for socket operations.
303pub struct NlSocketHandle {
304    socket: NlSocket,
305    buffer: Vec<u8>,
306    position: usize,
307    end: usize,
308    pub(super) needs_ack: bool,
309}
310
311type GenlFamily = Result<
312    NlBuffer<GenlId, Genlmsghdr<CtrlCmd, CtrlAttr>>,
313    NlError<GenlId, Genlmsghdr<CtrlCmd, CtrlAttr>>,
314>;
315
316impl NlSocketHandle {
317    /// Wrapper around `socket()` syscall filling in the
318    /// netlink-specific information
319    pub fn new(proto: NlFamily) -> Result<Self, io::Error> {
320        Ok(NlSocketHandle {
321            socket: NlSocket::new(proto)?,
322            buffer: vec![0; MAX_NL_LENGTH],
323            position: 0,
324            end: 0,
325            needs_ack: false,
326        })
327    }
328
329    /// Equivalent of `socket` and `bind` calls.
330    pub fn connect(proto: NlFamily, pid: Option<u32>, groups: &[u32]) -> Result<Self, io::Error> {
331        Ok(NlSocketHandle {
332            socket: NlSocket::connect(proto, pid, groups)?,
333            buffer: vec![0; MAX_NL_LENGTH],
334            position: 0,
335            end: 0,
336            needs_ack: false,
337        })
338    }
339
340    /// Set underlying socket file descriptor to be blocking.
341    pub fn block(&self) -> Result<(), io::Error> {
342        self.socket.block()
343    }
344
345    /// Set underlying socket file descriptor to be non blocking.
346    pub fn nonblock(&self) -> Result<(), io::Error> {
347        self.socket.nonblock()
348    }
349
350    /// Determines if underlying file descriptor is blocking.
351    pub fn is_blocking(&self) -> Result<bool, io::Error> {
352        self.socket.is_blocking()
353    }
354
355    /// Use this function to bind to a netlink ID and subscribe to
356    /// groups. See netlink(7) man pages for more information on
357    /// netlink IDs and groups.
358    pub fn bind(&self, pid: Option<u32>, groups: &[u32]) -> Result<(), io::Error> {
359        self.socket.bind(pid, groups)
360    }
361
362    /// Join multicast groups for a socket.
363    pub fn add_mcast_membership(&self, groups: &[u32]) -> Result<(), io::Error> {
364        self.socket.add_mcast_membership(groups)
365    }
366
367    /// Leave multicast groups for a socket.
368    pub fn drop_mcast_membership(&self, groups: &[u32]) -> Result<(), io::Error> {
369        self.socket.drop_mcast_membership(groups)
370    }
371
372    /// List joined groups for a socket.
373    pub fn list_mcast_membership(&self) -> Result<NetlinkBitArray, io::Error> {
374        self.socket.list_mcast_membership()
375    }
376
377    /// Get the PID for the current socket.
378    pub fn pid(&self) -> Result<u32, io::Error> {
379        self.socket.pid()
380    }
381
382    fn get_genl_family(&mut self, family_name: &str) -> GenlFamily {
383        let mut attrs = GenlBuffer::new();
384        attrs.push(Nlattr::new(
385            false,
386            false,
387            CtrlAttr::FamilyName,
388            family_name,
389        )?);
390        let genlhdr = Genlmsghdr::new(CtrlCmd::Getfamily, 2, attrs);
391        let nlhdr = Nlmsghdr::new(
392            None,
393            GenlId::Ctrl,
394            NlmFFlags::new(&[NlmF::Request, NlmF::Ack]),
395            None,
396            None,
397            NlPayload::Payload(genlhdr),
398        );
399        self.send(nlhdr)?;
400
401        let mut buffer = NlBuffer::new();
402        for msg in self.iter(false) {
403            buffer.push(msg?);
404        }
405        Ok(buffer)
406    }
407
408    /// Convenience function for resolving a [`str`] containing the
409    /// generic netlink family name to a numeric generic netlink ID.
410    pub fn resolve_genl_family(
411        &mut self,
412        family_name: &str,
413    ) -> Result<u16, NlError<GenlId, Genlmsghdr<CtrlCmd, CtrlAttr>>> {
414        let mut res = Err(NlError::new(format!(
415            "Generic netlink family {} was not found",
416            family_name
417        )));
418
419        let nlhdrs = self.get_genl_family(family_name)?;
420        for nlhdr in nlhdrs.into_iter() {
421            if let NlPayload::Payload(p) = nlhdr.nl_payload {
422                let handle = p.get_attr_handle();
423                if let Ok(u) = handle.get_attr_payload_as::<u16>(CtrlAttr::FamilyId) {
424                    res = Ok(u);
425                }
426            }
427        }
428
429        res
430    }
431
432    /// Convenience function for resolving a [`str`] containing the
433    /// multicast group name to a numeric multicast group ID.
434    pub fn resolve_nl_mcast_group(
435        &mut self,
436        family_name: &str,
437        mcast_name: &str,
438    ) -> Result<u32, NlError<GenlId, Genlmsghdr<CtrlCmd, CtrlAttr>>> {
439        let mut res = Err(NlError::new(format!(
440            "Failed to resolve multicast group ID for family name {}, multicast group name {}",
441            family_name, mcast_name,
442        )));
443
444        let nlhdrs = self.get_genl_family(family_name)?;
445        for nlhdr in nlhdrs {
446            if let NlPayload::Payload(p) = nlhdr.nl_payload {
447                let mut handle = p.get_attr_handle();
448                let mcast_groups = handle.get_nested_attributes::<Index>(CtrlAttr::McastGroups)?;
449                if let Some(id) = mcast_groups.iter().find_map(|item| {
450                    let nested_attrs = item.get_attr_handle::<CtrlAttrMcastGrp>().ok()?;
451                    let string = nested_attrs
452                        .get_attr_payload_as_with_len::<String>(CtrlAttrMcastGrp::Name)
453                        .ok()?;
454                    if string.as_str() == mcast_name {
455                        nested_attrs
456                            .get_attr_payload_as::<u32>(CtrlAttrMcastGrp::Id)
457                            .ok()
458                    } else {
459                        None
460                    }
461                }) {
462                    res = Ok(id);
463                }
464            }
465        }
466
467        res
468    }
469
470    /// Look up netlink family and multicast group name by ID.
471    pub fn lookup_id(
472        &mut self,
473        id: u32,
474    ) -> Result<(String, String), NlError<GenlId, Genlmsghdr<CtrlCmd, CtrlAttr>>> {
475        let mut res = Err(NlError::new("ID does not correspond to a multicast group"));
476
477        let attrs = GenlBuffer::new();
478        let genlhdr = Genlmsghdr::<CtrlCmd, CtrlAttr>::new(CtrlCmd::Getfamily, 2, attrs);
479        let nlhdr = Nlmsghdr::new(
480            None,
481            GenlId::Ctrl,
482            NlmFFlags::new(&[NlmF::Request, NlmF::Dump]),
483            None,
484            None,
485            NlPayload::Payload(genlhdr),
486        );
487
488        self.send(nlhdr)?;
489        for res_msg in self.iter::<GenlId, Genlmsghdr<CtrlCmd, CtrlAttr>>(false) {
490            let msg = res_msg?;
491
492            if let NlPayload::Payload(p) = msg.nl_payload {
493                let mut attributes = p.get_attr_handle();
494                let name =
495                    attributes.get_attr_payload_as_with_len::<String>(CtrlAttr::FamilyName)?;
496                let groups = match attributes.get_nested_attributes::<Index>(CtrlAttr::McastGroups)
497                {
498                    Ok(grps) => grps,
499                    Err(_) => continue,
500                };
501                for group_by_index in groups.iter() {
502                    let attributes = group_by_index.get_attr_handle::<CtrlAttrMcastGrp>()?;
503                    if let Ok(mcid) = attributes.get_attr_payload_as::<u32>(CtrlAttrMcastGrp::Id) {
504                        if mcid == id {
505                            let mcast_name = attributes
506                                .get_attr_payload_as_with_len::<String>(CtrlAttrMcastGrp::Name)?;
507                            res = Ok((name.clone(), mcast_name));
508                        }
509                    }
510                }
511            }
512        }
513
514        res
515    }
516
517    /// Convenience function to send an [`Nlmsghdr`] struct
518    pub fn send<T, P>(&mut self, msg: Nlmsghdr<T, P>) -> Result<(), SerError>
519    where
520        T: NlType + Debug,
521        P: ToBytes + Debug,
522    {
523        debug!("Message sent:\n{:?}", msg);
524
525        if msg.nl_flags.contains(&NlmF::Ack) && !msg.nl_flags.contains(&NlmF::Dump) {
526            self.needs_ack = true;
527        }
528
529        let mut buffer = Cursor::new(Vec::new());
530        msg.to_bytes(&mut buffer)?;
531        self.socket.send(buffer.get_ref(), 0)?;
532
533        Ok(())
534    }
535
536    /// Convenience function to read a stream of
537    /// [`Nlmsghdr`][crate::nl::Nlmsghdr] structs one by one.
538    /// Use [`NlSocketHandle::iter`] instead for easy iteration over
539    /// returned packets.
540    ///
541    /// Returns [`None`] only in non-blocking contexts if no
542    /// message can be immediately returned or if the socket
543    /// has been closed.
544    pub fn recv<'a, T, P>(&'a mut self) -> Result<Option<Nlmsghdr<T, P>>, NlError<T, P>>
545    where
546        T: NlType + Debug,
547        P: FromBytesWithInput<'a, Input = usize> + Debug,
548    {
549        if self.end == self.position {
550            // Read the buffer from the socket and fail if nothing
551            // was read.
552            let mem_read_res = self.socket.recv(&mut self.buffer, 0);
553            if let Err(ref e) = mem_read_res {
554                if e.kind() == io::ErrorKind::WouldBlock {
555                    return Ok(None);
556                }
557            }
558            let mem_read = mem_read_res?;
559            if mem_read == 0 {
560                return Ok(None);
561            }
562            self.position = 0;
563            self.end = mem_read;
564        }
565
566        let (packet_res, next_packet_len) = {
567            let end = self.buffer.len();
568            // Get the next packet length at the current position of the
569            // buffer for the next read operation.
570            if self.position == end {
571                return Ok(None);
572            }
573            let next_packet_len = packet_length_u32(&self.buffer, self.position);
574            // If the packet extends past the end of the number of bytes
575            // read into the buffer, return an error; something
576            // has gone wrong.
577            if self.position + next_packet_len > end {
578                return Err(NlError::new("Incomplete packet received from socket"));
579            }
580
581            // Deserialize the next Nlmsghdr struct.
582            let deserialized_packet_result = Nlmsghdr::<T, P>::from_bytes(&mut Cursor::new(
583                &self.buffer[self.position..self.position + next_packet_len],
584            ));
585
586            (deserialized_packet_result, next_packet_len)
587        };
588
589        let packet = match packet_res {
590            Ok(packet) => {
591                // If successful, forward the position of the buffer
592                // for the next read.
593                self.position += next_packet_len;
594
595                packet
596            }
597            Err(e) => return Err(NlError::De(e)),
598        };
599
600        debug!("Message received: {:?}", packet);
601
602        if let NlPayload::Err(e) = packet.nl_payload {
603            return Err(NlError::<T, P>::from(e));
604        } else if let NlPayload::Ack(_) = packet.nl_payload {
605            if self.needs_ack {
606                self.needs_ack = false;
607            } else {
608                return Err(NlError::new(
609                    "Socket did not expect an ACK but one was received",
610                ));
611            }
612        }
613
614        Ok(Some(packet))
615    }
616
617    /// Parse all [`Nlmsghdr`][crate::nl::Nlmsghdr] structs sent in
618    /// one network packet and return them all in a list.
619    ///
620    /// Failure to parse any packet will cause the entire operation
621    /// to fail. If an error is detected at the application level,
622    /// this method will discard any non-error
623    /// [`Nlmsghdr`][crate::nl::Nlmsghdr] structs and only return the
624    /// error. This method checks for ACKs. For a more granular
625    /// approach, use either [`NlSocketHandle::recv`] or
626    /// [`NlSocketHandle::iter`].
627    pub fn recv_all<'a, T, P>(&'a mut self) -> Result<NlBuffer<T, P>, NlError>
628    where
629        T: NlType + Debug,
630        P: FromBytesWithInput<'a, Input = usize> + Debug,
631    {
632        if self.position == self.end {
633            let mem_read = self.socket.recv(&mut self.buffer, 0)?;
634            if mem_read == 0 {
635                return Err(NlError::new("No data could be read from the socket"));
636            }
637            self.end = mem_read;
638        }
639
640        let vec =
641            NlBuffer::from_bytes_with_input(&mut Cursor::new(&self.buffer[0..self.end]), self.end)?;
642
643        debug!("Messages received: {:?}", vec);
644
645        self.position = 0;
646        self.end = 0;
647        Ok(vec)
648    }
649
650    /// Return an iterator object
651    ///
652    /// The argument `iterate_indefinitely` is documented
653    /// in more detail in [`NlMessageIter`]
654    pub fn iter<'a, T, P>(&'a mut self, iter_indefinitely: bool) -> NlMessageIter<'a, T, P>
655    where
656        T: NlType + Debug,
657        P: FromBytesWithInput<'a, Input = usize> + Debug,
658    {
659        let behavior = if iter_indefinitely {
660            IterationBehavior::IterIndefinitely
661        } else {
662            IterationBehavior::EndMultiOnDone
663        };
664        NlMessageIter::new(self, behavior)
665    }
666}
667
668impl AsRawFd for NlSocketHandle {
669    fn as_raw_fd(&self) -> RawFd {
670        self.socket.as_raw_fd()
671    }
672}
673
674impl IntoRawFd for NlSocketHandle {
675    fn into_raw_fd(self) -> RawFd {
676        self.socket.into_raw_fd()
677    }
678}
679
680impl FromRawFd for NlSocketHandle {
681    unsafe fn from_raw_fd(fd: RawFd) -> Self {
682        NlSocketHandle {
683            socket: NlSocket::from_raw_fd(fd),
684            buffer: vec![0; MAX_NL_LENGTH],
685            end: 0,
686            position: 0,
687            needs_ack: false,
688        }
689    }
690}
691
#[cfg(all(feature = "async", not(no_std)))]
pub mod tokio {
    //! Tokio-specific features for neli
    //!
    //! This module contains a struct that wraps [`NlSocket`] for
    //! async IO.
    use super::*;

    use std::{
        pin::Pin,
        sync::Arc,
        task::{Context, Poll},
    };

    use ::tokio::io::{unix::AsyncFd, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};

    use crate::{err::DeError, Size};

    /// Unwraps a `Poll::Ready` value or returns `Poll::Pending` from the
    /// enclosing function.
    macro_rules! ready {
        ($e:expr $(,)?) => {
            match $e {
                ::std::task::Poll::Ready(t) => t,
                ::std::task::Poll::Pending => return ::std::task::Poll::Pending,
            }
        };
    }

    /// Waits for read readiness and then receives one datagram into the
    /// unfilled part of `buf`, retrying when the receive would block.
    fn poll_read_priv(
        async_fd: &AsyncFd<super::NlSocket>,
        cx: &mut Context,
        buf: &mut ReadBuf,
    ) -> Poll<io::Result<usize>> {
        loop {
            let mut guard = ready!(async_fd.poll_read_ready(cx))?;
            match guard.try_io(|fd| {
                // Receive into the unfilled region only. Writing into
                // `initialized_mut()` would overwrite bytes already marked
                // as filled, and the following `advance` could then exceed
                // the initialized length and panic.
                let unfilled = buf.initialize_unfilled();
                let bytes_read = fd.get_ref().recv(unfilled, 0)?;
                buf.advance(bytes_read);
                Ok(bytes_read)
            }) {
                Ok(result) => return Poll::Ready(result),
                // `try_io` clears readiness on `WouldBlock`; wait again.
                Err(_would_block) => continue,
            }
        }
    }

    /// Waits for write readiness and then sends `buf` as one datagram,
    /// retrying when the send would block.
    fn poll_write_priv(
        async_fd: &AsyncFd<super::NlSocket>,
        cx: &mut Context,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        loop {
            let mut guard = ready!(async_fd.poll_write_ready(cx))?;
            // Let `try_io` clear readiness only when the send reports
            // `WouldBlock`. Clearing it unconditionally makes the next
            // write wait for a fresh readiness event. Returning a raw
            // `WouldBlock` to the caller would surface a spurious error.
            match guard.try_io(|fd| fd.get_ref().send(buf, 0)) {
                Ok(result) => return Poll::Ready(result),
                Err(_would_block) => continue,
            }
        }
    }

    /// Tokio-enabled Netlink socket struct
    pub struct NlSocket {
        socket: Arc<AsyncFd<super::NlSocket>>,
    }

    impl NlSocket {
        /// Set up [`NlSocket`][crate::socket::NlSocket] for use
        /// with tokio; set to nonblocking state and wrap in polling
        /// mechanism.
        ///
        /// # Errors
        /// Returns an error if the blocking mode cannot be queried or
        /// changed, or if registration with the tokio reactor fails.
        pub fn new<S>(s: S) -> io::Result<Self>
        where
            S: Into<super::NlSocket>,
        {
            let socket = s.into();
            if socket.is_blocking()? {
                socket.nonblock()?;
            }
            Ok(NlSocket {
                socket: Arc::new(AsyncFd::new(socket)?),
            })
        }

        /// Send a message on the socket asynchronously.
        ///
        /// The message is serialized into a buffer of
        /// `msg.padded_size()` bytes, and that buffer is written out.
        pub async fn send<T, P>(&mut self, msg: &Nlmsghdr<T, P>) -> Result<(), SerError>
        where
            T: NlType,
            P: Size + ToBytes,
        {
            let mut buffer = Cursor::new(vec![0; msg.padded_size()]);
            msg.to_bytes(&mut buffer)?;
            self.write_all(buffer.get_ref()).await?;
            Ok(())
        }

        /// Receive a message from the socket asynchronously.
        ///
        /// `buffer` is resized to `MAX_NL_LENGTH`, filled by a single read,
        /// truncated to the number of bytes received, and then parsed into
        /// the returned messages.
        pub async fn recv<'a, T, P>(
            &mut self,
            buffer: &'a mut Vec<u8>,
        ) -> Result<NlBuffer<T, P>, DeError>
        where
            T: NlType,
            P: FromBytesWithInput<'a, Input = usize>,
        {
            if buffer.len() != MAX_NL_LENGTH {
                buffer.resize(MAX_NL_LENGTH, 0);
            }
            let bytes = self.read(buffer.as_mut_slice()).await?;
            buffer.truncate(bytes);
            NlBuffer::from_bytes_with_input(&mut Cursor::new(buffer.as_slice()), bytes)
        }
    }

    impl AsyncRead for NlSocket {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: &mut ReadBuf,
        ) -> Poll<io::Result<()>> {
            let _ = ready!(poll_read_priv(&self.socket, cx, buf))?;
            Poll::Ready(Ok(()))
        }
    }

    impl AsyncWrite for NlSocket {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            poll_write_priv(&self.socket, cx, buf)
        }

        /// Datagrams are sent immediately, so there is nothing to flush.
        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        /// Shutting down is a no-op; the descriptor is closed when the
        /// inner socket is dropped.
        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    impl Unpin for NlSocket {}
}
832
#[cfg(test)]
mod test {
    use super::*;

    use crate::{consts::nl::Nlmsg, test::setup};

    #[test]
    fn multi_msg_iter() {
        setup();

        // Builds a multipart `Noop` message whose generic netlink payload
        // carries a family ID and a family name.
        let family_msg = |family_id: u32, family_name: &'static str| {
            let mut attrs = GenlBuffer::new();
            attrs.push(Nlattr::new(false, false, CtrlAttr::FamilyId, family_id).unwrap());
            attrs.push(Nlattr::new(false, false, CtrlAttr::FamilyName, family_name).unwrap());
            Nlmsghdr::new(
                None,
                NlTypeWrapper::Nlmsg(Nlmsg::Noop),
                NlmFFlags::new(&[NlmF::Multi]),
                None,
                None,
                NlPayload::Payload(Genlmsghdr::new(CtrlCmd::Unspec, 2, attrs)),
            )
        };

        let mut expected = NlBuffer::new();
        expected.push(family_msg(5, "my_family_name"));
        expected.push(family_msg(6, "my_other_family_name"));

        let mut serialized = Cursor::new(Vec::new());
        expected.to_bytes(&mut serialized).unwrap();
        let raw = serialized.into_inner();
        let raw_len = raw.len();

        // Pre-load the handle's buffer so iteration never touches the
        // (invalid) descriptor.
        let mut handle = NlSocketHandle {
            socket: unsafe { NlSocket::from_raw_fd(-1) },
            buffer: raw,
            position: 0,
            end: raw_len,
            needs_ack: false,
        };

        let mut received = NlBuffer::new();
        let mut messages = handle.iter(false);
        for _ in 0..2 {
            match messages.next() {
                Some(Ok(msg)) => received.push(msg),
                _ => panic!("Expected message not found"),
            }
        }
        assert_eq!(received, expected);
    }

    #[test]
    fn real_test_mcast_groups() {
        setup();

        let mut handle = NlSocketHandle::new(NlFamily::Generic).unwrap();
        let notify = handle.resolve_nl_mcast_group("nlctrl", "notify");
        let config = handle.resolve_nl_mcast_group("devlink", "config");

        // Families missing on this kernel report `Nlmsgerr` and are
        // skipped; any other error fails the test.
        let group_ids: Vec<u32> = match (notify, config) {
            (Ok(n), Ok(c)) => vec![n, c],
            (Ok(n), Err(NlError::Nlmsgerr(_))) => vec![n],
            (Err(NlError::Nlmsgerr(_)), Ok(c)) => vec![c],
            (Err(NlError::Nlmsgerr(_)), Err(NlError::Nlmsgerr(_))) => return,
            (Err(e), _) | (_, Err(e)) => {
                panic!("Unexpected result from resolve_nl_mcast_group: {:?}", e)
            }
        };
        handle.add_mcast_membership(&group_ids).unwrap();

        let joined = handle.list_mcast_membership().unwrap();
        assert!(group_ids.iter().all(|&id| joined.is_set(id as usize)));

        handle.drop_mcast_membership(&group_ids).unwrap();
        let joined = handle.list_mcast_membership().unwrap();
        assert!(group_ids.iter().all(|&id| !joined.is_set(id as usize)));
    }

    #[test]
    fn real_test_pid() {
        setup();

        let socket = NlSocket::connect(NlFamily::Generic, Some(5555), &[]).unwrap();
        assert_eq!(socket.pid().unwrap(), 5555);
    }
}