shadow_rs/host/descriptor/socket/netlink.rs

1use std::io::{Cursor, ErrorKind, Read, Write};
2use std::net::Ipv4Addr;
3use std::sync::{Arc, Weak};
4
5use atomic_refcell::AtomicRefCell;
6use linux_api::errno::Errno;
7use linux_api::ioctls::IoctlRequest;
8use linux_api::netlink::{ifaddrmsg, ifinfomsg, nlmsghdr};
9use linux_api::rtnetlink::{RTM_GETADDR, RTM_GETLINK, RTMGRP_IPV4_IFADDR, RTMGRP_IPV6_IFADDR};
10use linux_api::socket::Shutdown;
11use neli::consts::nl::{NlmF, Nlmsg};
12use neli::consts::rtnl::{Arphrd, Ifa, IfaF, Iff, Ifla, RtAddrFamily, RtScope, Rtm};
13use neli::nl::{NlPayload, Nlmsghdr, NlmsghdrBuilder};
14use neli::rtnl::{Ifaddrmsg, IfaddrmsgBuilder, Ifinfomsg, IfinfomsgBuilder, RtattrBuilder};
15use neli::types::{Buffer, RtBuffer};
16use neli::{FromBytes, ToBytes};
17use nix::sys::socket::{MsgFlags, NetlinkAddr};
18use shadow_shim_helper_rs::syscall_types::ForeignPtr;
19
20use crate::core::worker::Worker;
21use crate::cshadow as c;
22use crate::host::descriptor::listener::{StateEventSource, StateListenHandle, StateListenerFilter};
23use crate::host::descriptor::shared_buf::{
24    BufferHandle, BufferSignals, BufferState, ReaderHandle, SharedBuf,
25};
26use crate::host::descriptor::socket::{RecvmsgArgs, RecvmsgReturn, SendmsgArgs, Socket};
27use crate::host::descriptor::{
28    File, FileMode, FileSignals, FileState, FileStatus, OpenFile, SyscallResult,
29};
30use crate::host::memory_manager::MemoryManager;
31use crate::host::network::namespace::NetworkNamespace;
32use crate::host::syscall::io::{IoVec, IoVecReader, IoVecWriter};
33use crate::host::syscall::types::SyscallError;
34use crate::utility::HostTreePointer;
35use crate::utility::callback_queue::CallbackQueue;
36use crate::utility::sockaddr::SockaddrStorage;
37
/// Default send limit (in bytes) of a netlink socket: 208 KiB.
///
/// This mirrors `UNIX_SOCKET_DEFAULT_BUFFER_SIZE`, from which the value was copied.
const NETLINK_SOCKET_DEFAULT_BUFFER_SIZE: u64 = 208 * 1024;
40
/// An emulated `AF_NETLINK` socket.
///
/// The socket is split into data shared by every protocol state (`common`) and the current
/// protocol state (`protocol_state`), which is either initial or closed.
pub struct NetlinkSocket {
    /// Data and functionality that is general for all states.
    common: NetlinkSocketCommon,
    /// State-specific data and functionality.
    protocol_state: ProtocolState,
}
47
48impl NetlinkSocket {
49    pub fn new(
50        status: FileStatus,
51        _socket_type: NetlinkSocketType,
52        _family: NetlinkFamily,
53    ) -> Arc<AtomicRefCell<Self>> {
54        Arc::new_cyclic(|weak| {
55            // each socket tracks its own send limit
56            let buffer = SharedBuf::new(usize::MAX);
57            let buffer = Arc::new(AtomicRefCell::new(buffer));
58
59            // Get the IP address of the host
60            let default_ip = Worker::with_active_host(|host| host.default_ip()).unwrap();
61            // All the interface configurations are the same as in the getifaddrs function handler
62            let interfaces = vec![
63                Interface {
64                    address: Ipv4Addr::LOCALHOST,
65                    label: String::from("lo"),
66                    prefix_len: 8,
67                    if_type: Arphrd::Loopback,
68                    mtu: c::CONFIG_MTU,
69                    scope: RtScope::Host,
70                    index: 1,
71                },
72                Interface {
73                    address: default_ip,
74                    label: String::from("eth0"),
75                    prefix_len: 24,
76                    if_type: Arphrd::Ether,
77                    mtu: c::CONFIG_MTU,
78                    scope: RtScope::Universe,
79                    index: 2,
80                },
81            ];
82
83            let mut common = NetlinkSocketCommon {
84                buffer,
85                send_limit: NETLINK_SOCKET_DEFAULT_BUFFER_SIZE,
86                sent_len: 0,
87                event_source: StateEventSource::new(),
88                state: FileState::ACTIVE,
89                status,
90                has_open_file: false,
91                interfaces,
92            };
93            let protocol_state = ProtocolState::new(&mut common, weak);
94            let mut socket = Self {
95                common,
96                protocol_state,
97            };
98            CallbackQueue::queue_and_run_with_legacy(|cb_queue| {
99                socket.refresh_file_state(FileSignals::empty(), cb_queue)
100            });
101
102            AtomicRefCell::new(socket)
103        })
104    }
105
106    pub fn status(&self) -> FileStatus {
107        self.common.status
108    }
109
110    pub fn set_status(&mut self, status: FileStatus) {
111        self.common.status = status;
112    }
113
114    pub fn mode(&self) -> FileMode {
115        FileMode::READ | FileMode::WRITE
116    }
117
118    pub fn has_open_file(&self) -> bool {
119        self.common.has_open_file
120    }
121
122    pub fn supports_sa_restart(&self) -> bool {
123        self.common.supports_sa_restart()
124    }
125
126    pub fn set_has_open_file(&mut self, val: bool) {
127        self.common.has_open_file = val;
128    }
129
130    pub fn getsockname(&self) -> Result<Option<nix::sys::socket::NetlinkAddr>, Errno> {
131        self.protocol_state.bound_address()
132    }
133
134    pub fn getpeername(&self) -> Result<Option<nix::sys::socket::NetlinkAddr>, Errno> {
135        warn_once_then_debug!(
136            "getpeername() syscall not yet supported for netlink sockets; Returning ENOSYS"
137        );
138        Err(Errno::ENOSYS)
139    }
140
141    pub fn address_family(&self) -> linux_api::socket::AddressFamily {
142        linux_api::socket::AddressFamily::AF_NETLINK
143    }
144
145    pub fn close(&mut self, cb_queue: &mut CallbackQueue) -> Result<(), SyscallError> {
146        self.protocol_state.close(&mut self.common, cb_queue)
147    }
148
149    fn refresh_file_state(&mut self, signals: FileSignals, cb_queue: &mut CallbackQueue) {
150        self.protocol_state
151            .refresh_file_state(&mut self.common, signals, cb_queue)
152    }
153
154    pub fn shutdown(
155        &mut self,
156        _how: Shutdown,
157        _cb_queue: &mut CallbackQueue,
158    ) -> Result<(), SyscallError> {
159        warn_once_then_debug!(
160            "shutdown() syscall not yet supported for netlink sockets; Returning ENOSYS"
161        );
162        Err(Errno::ENOSYS.into())
163    }
164
165    pub fn getsockopt(
166        &mut self,
167        _level: libc::c_int,
168        _optname: libc::c_int,
169        _optval_ptr: ForeignPtr<()>,
170        _optlen: libc::socklen_t,
171        _memory_manager: &mut MemoryManager,
172        _cb_queue: &mut CallbackQueue,
173    ) -> Result<libc::socklen_t, SyscallError> {
174        warn_once_then_debug!(
175            "getsockopt() syscall not yet supported for netlink sockets; Returning ENOSYS"
176        );
177        Err(Errno::ENOSYS.into())
178    }
179
180    pub fn setsockopt(
181        &mut self,
182        level: libc::c_int,
183        optname: libc::c_int,
184        optval_ptr: ForeignPtr<()>,
185        optlen: libc::socklen_t,
186        memory_manager: &MemoryManager,
187    ) -> Result<(), SyscallError> {
188        match (level, optname) {
189            (libc::SOL_SOCKET, libc::SO_SNDBUF) => {
190                type OptType = libc::c_int;
191
192                if usize::try_from(optlen).unwrap() < std::mem::size_of::<OptType>() {
193                    return Err(Errno::EINVAL.into());
194                }
195
196                let optval_ptr = optval_ptr.cast::<OptType>();
197                let val: u64 = memory_manager
198                    .read(optval_ptr)?
199                    .try_into()
200                    .or(Err(Errno::EINVAL))?;
201
202                // Linux kernel doubles this value upon setting
203                let val = val * 2;
204                // We want to keep sent_len lower than send_limit
205                let val = std::cmp::max(val, self.common.sent_len);
206                // Copied the following behaviour from setsockopt of LegacyTcpSocket
207                let val = std::cmp::max(val, 4096);
208                let val = std::cmp::min(val, 268435456); // 2^28 = 256 MiB
209
210                self.common.send_limit = val;
211            }
212            (libc::SOL_SOCKET, libc::SO_RCVBUF) => {
213                // We don't care about the receive buffer size because we already limit the send
214                // buffer size and when recvmsg is called we just retrieve the request packet from
215                // the send buffer, process it, and return the response immediately to the caller
216            }
217            _ => {
218                warn_once_then_debug!(
219                    "setsockopt called with unsupported level {level} and opt {optname}"
220                );
221                return Err(Errno::ENOPROTOOPT.into());
222            }
223        }
224
225        Ok(())
226    }
227
228    pub fn bind(
229        socket: &Arc<AtomicRefCell<Self>>,
230        addr: Option<&SockaddrStorage>,
231        _net_ns: &NetworkNamespace,
232        rng: impl rand::Rng,
233    ) -> Result<(), SyscallError> {
234        let socket_ref = &mut *socket.borrow_mut();
235        socket_ref
236            .protocol_state
237            .bind(&mut socket_ref.common, socket, addr, rng)
238    }
239
240    pub fn readv(
241        &mut self,
242        _iovs: &[IoVec],
243        _offset: Option<libc::off_t>,
244        _flags: libc::c_int,
245        _mem: &mut MemoryManager,
246        _cb_queue: &mut CallbackQueue,
247    ) -> Result<libc::ssize_t, SyscallError> {
248        // we could call NetlinkSocket::recvmsg() here, but for now we expect that there are no code
249        // paths that would call NetlinkSocket::readv() since the readv() syscall handler should have
250        // called NetlinkSocket::recvmsg() instead
251        panic!("Called NetlinkSocket::readv() on a netlink socket.");
252    }
253
254    pub fn writev(
255        &mut self,
256        _iovs: &[IoVec],
257        _offset: Option<libc::off_t>,
258        _flags: libc::c_int,
259        _mem: &mut MemoryManager,
260        _cb_queue: &mut CallbackQueue,
261    ) -> Result<libc::ssize_t, SyscallError> {
262        // we could call NetlinkSocket::sendmsg() here, but for now we expect that there are no code
263        // paths that would call NetlinkSocket::writev() since the writev() syscall handler should have
264        // called NetlinkSocket::sendmsg() instead
265        panic!("Called NetlinkSocket::writev() on a netlink socket");
266    }
267
268    pub fn sendmsg(
269        socket: &Arc<AtomicRefCell<Self>>,
270        args: SendmsgArgs,
271        mem: &mut MemoryManager,
272        _net_ns: &NetworkNamespace,
273        _rng: impl rand::Rng,
274        cb_queue: &mut CallbackQueue,
275    ) -> Result<libc::ssize_t, SyscallError> {
276        let socket_ref = &mut *socket.borrow_mut();
277        socket_ref
278            .protocol_state
279            .sendmsg(&mut socket_ref.common, socket, args, mem, cb_queue)
280    }
281
282    pub fn recvmsg(
283        socket: &Arc<AtomicRefCell<Self>>,
284        args: RecvmsgArgs,
285        mem: &mut MemoryManager,
286        cb_queue: &mut CallbackQueue,
287    ) -> Result<RecvmsgReturn, SyscallError> {
288        let socket_ref = &mut *socket.borrow_mut();
289        socket_ref
290            .protocol_state
291            .recvmsg(&mut socket_ref.common, socket, args, mem, cb_queue)
292    }
293
294    pub fn listen(
295        _socket: &Arc<AtomicRefCell<Self>>,
296        _backlog: i32,
297        _net_ns: &NetworkNamespace,
298        _rng: impl rand::Rng,
299        _cb_queue: &mut CallbackQueue,
300    ) -> Result<(), Errno> {
301        warn_once_then_debug!("We do not yet handle listen request on netlink sockets");
302        Err(Errno::EINVAL)
303    }
304
305    pub fn connect(
306        _socket: &Arc<AtomicRefCell<Self>>,
307        _addr: &SockaddrStorage,
308        _net_ns: &NetworkNamespace,
309        _rng: impl rand::Rng,
310        _cb_queue: &mut CallbackQueue,
311    ) -> Result<(), SyscallError> {
312        warn_once_then_debug!("We do not yet handle connect request on netlink sockets");
313        Err(Errno::EINVAL.into())
314    }
315
316    pub fn accept(
317        &mut self,
318        _net_ns: &NetworkNamespace,
319        _rng: impl rand::Rng,
320        _cb_queue: &mut CallbackQueue,
321    ) -> Result<OpenFile, SyscallError> {
322        warn_once_then_debug!("We do not yet handle accept request on netlink sockets");
323        Err(Errno::EINVAL.into())
324    }
325
326    pub fn ioctl(
327        &mut self,
328        request: IoctlRequest,
329        _arg_ptr: ForeignPtr<()>,
330        _memory_manager: &mut MemoryManager,
331    ) -> SyscallResult {
332        warn_once_then_debug!("We do not yet handle ioctl request {request:?} on netlink sockets");
333        Err(Errno::EINVAL.into())
334    }
335
336    pub fn stat(&self) -> Result<linux_api::stat::stat, SyscallError> {
337        warn_once_then_debug!("We do not yet handle stat calls on netlink sockets");
338        Err(Errno::EINVAL.into())
339    }
340
341    pub fn add_listener(
342        &mut self,
343        monitoring_state: FileState,
344        monitoring_signals: FileSignals,
345        filter: StateListenerFilter,
346        notify_fn: impl Fn(FileState, FileState, FileSignals, &mut CallbackQueue)
347        + Send
348        + Sync
349        + 'static,
350    ) -> StateListenHandle {
351        self.common.event_source.add_listener(
352            monitoring_state,
353            monitoring_signals,
354            filter,
355            notify_fn,
356        )
357    }
358
359    pub fn add_legacy_listener(&mut self, ptr: HostTreePointer<c::StatusListener>) {
360        self.common.event_source.add_legacy_listener(ptr);
361    }
362
363    pub fn remove_legacy_listener(&mut self, ptr: *mut c::StatusListener) {
364        self.common.event_source.remove_legacy_listener(ptr);
365    }
366
367    pub fn state(&self) -> FileState {
368        self.common.state
369    }
370}
371
/// Protocol state of a netlink socket that has been created but not yet closed.
struct InitialState {
    /// The address recorded by `bind()`, or `None` if the socket hasn't been bound.
    bound_addr: Option<NetlinkAddr>,
    /// This socket's registration as a reader of the shared buffer; handed back to the buffer
    /// in `close()`.
    reader_handle: ReaderHandle,
    /// Handle returned when the buffer listener was registered in `ProtocolState::new`.
    /// This handle is never accessed, but we store it because of its drop impl.
    _buffer_handle: BufferHandle,
}
/// Protocol state of a netlink socket after `close()`.
struct ClosedState {}
/// The current protocol state of the netlink socket.
///
/// An `Option` is required for each variant so that the inner state object can be removed (with
/// `Option::take`), transformed into a new state, and then re-added as a different variant (see
/// `ProtocolState::close`). The methods of `ProtocolState` assume the `Option` is `Some` and
/// unwrap it.
enum ProtocolState {
    /// The socket is open (bound or not).
    Initial(Option<InitialState>),
    /// The socket has been closed.
    Closed(Option<ClosedState>),
}
386
/// Upcast from a state type to an enum variant.
///
/// `state_upcast!(S, E::V)` generates `impl From<S> for E`, which wraps the value as
/// `E::V(Some(value))`.
macro_rules! state_upcast {
    ($state:ty, $enum_name:ident::$variant:ident) => {
        impl From<$state> for $enum_name {
            fn from(state: $state) -> Self {
                $enum_name::$variant(Some(state))
            }
        }
    };
}
397
398// implement upcasting for all state types
399state_upcast!(InitialState, ProtocolState::Initial);
400state_upcast!(ClosedState, ProtocolState::Closed);
401
402impl ProtocolState {
403    fn new(common: &mut NetlinkSocketCommon, socket: &Weak<AtomicRefCell<NetlinkSocket>>) -> Self {
404        // this is a new socket and there are no listeners, so safe to use a temporary event queue
405        let mut cb_queue = CallbackQueue::new();
406
407        // increment the buffer's reader count
408        let reader_handle = common.buffer.borrow_mut().add_reader(&mut cb_queue);
409
410        let weak = Weak::clone(socket);
411        let buffer_handle = common.buffer.borrow_mut().add_listener(
412            BufferState::READABLE,
413            BufferSignals::BUFFER_GREW,
414            move |_, signals, cb_queue| {
415                if let Some(socket) = weak.upgrade() {
416                    let signals = if signals.contains(BufferSignals::BUFFER_GREW) {
417                        FileSignals::READ_BUFFER_GREW
418                    } else {
419                        FileSignals::empty()
420                    };
421
422                    socket.borrow_mut().refresh_file_state(signals, cb_queue);
423                }
424            },
425        );
426
427        ProtocolState::Initial(Some(InitialState {
428            bound_addr: None,
429            reader_handle,
430            _buffer_handle: buffer_handle,
431        }))
432    }
433
434    fn bound_address(&self) -> Result<Option<NetlinkAddr>, Errno> {
435        match self {
436            Self::Initial(x) => x.as_ref().unwrap().bound_address(),
437            Self::Closed(x) => x.as_ref().unwrap().bound_address(),
438        }
439    }
440
441    fn refresh_file_state(
442        &self,
443        common: &mut NetlinkSocketCommon,
444        signals: FileSignals,
445        cb_queue: &mut CallbackQueue,
446    ) {
447        match self {
448            Self::Initial(x) => x
449                .as_ref()
450                .unwrap()
451                .refresh_file_state(common, signals, cb_queue),
452            Self::Closed(x) => x
453                .as_ref()
454                .unwrap()
455                .refresh_file_state(common, signals, cb_queue),
456        }
457    }
458
459    fn close(
460        &mut self,
461        common: &mut NetlinkSocketCommon,
462        cb_queue: &mut CallbackQueue,
463    ) -> Result<(), SyscallError> {
464        let (new_state, rv) = match self {
465            Self::Initial(x) => x.take().unwrap().close(common, cb_queue),
466            Self::Closed(x) => x.take().unwrap().close(common, cb_queue),
467        };
468
469        *self = new_state;
470        rv
471    }
472
473    fn bind(
474        &mut self,
475        common: &mut NetlinkSocketCommon,
476        socket: &Arc<AtomicRefCell<NetlinkSocket>>,
477        addr: Option<&SockaddrStorage>,
478        rng: impl rand::Rng,
479    ) -> Result<(), SyscallError> {
480        match self {
481            Self::Initial(x) => x.as_mut().unwrap().bind(common, socket, addr, rng),
482            Self::Closed(x) => x.as_mut().unwrap().bind(common, socket, addr, rng),
483        }
484    }
485
486    fn sendmsg(
487        &mut self,
488        common: &mut NetlinkSocketCommon,
489        socket: &Arc<AtomicRefCell<NetlinkSocket>>,
490        args: SendmsgArgs,
491        mem: &mut MemoryManager,
492        cb_queue: &mut CallbackQueue,
493    ) -> Result<libc::ssize_t, SyscallError> {
494        match self {
495            Self::Initial(x) => x
496                .as_mut()
497                .unwrap()
498                .sendmsg(common, socket, args, mem, cb_queue),
499            Self::Closed(x) => x
500                .as_mut()
501                .unwrap()
502                .sendmsg(common, socket, args, mem, cb_queue),
503        }
504    }
505
506    fn recvmsg(
507        &mut self,
508        common: &mut NetlinkSocketCommon,
509        socket: &Arc<AtomicRefCell<NetlinkSocket>>,
510        args: RecvmsgArgs,
511        mem: &mut MemoryManager,
512        cb_queue: &mut CallbackQueue,
513    ) -> Result<RecvmsgReturn, SyscallError> {
514        match self {
515            Self::Initial(x) => x
516                .as_mut()
517                .unwrap()
518                .recvmsg(common, socket, args, mem, cb_queue),
519            Self::Closed(x) => x
520                .as_mut()
521                .unwrap()
522                .recvmsg(common, socket, args, mem, cb_queue),
523        }
524    }
525}
526
527impl InitialState {
528    fn bound_address(&self) -> Result<Option<NetlinkAddr>, Errno> {
529        Ok(self.bound_addr)
530    }
531
532    fn refresh_file_state(
533        &self,
534        common: &mut NetlinkSocketCommon,
535        signals: FileSignals,
536        cb_queue: &mut CallbackQueue,
537    ) {
538        let mut new_state = FileState::ACTIVE;
539
540        {
541            let buffer = common.buffer.borrow();
542
543            new_state.set(FileState::READABLE, buffer.has_data());
544            new_state.set(FileState::WRITABLE, common.sent_len < common.send_limit);
545        }
546
547        common.update_state(
548            /* mask= */ FileState::all(),
549            new_state,
550            signals,
551            cb_queue,
552        );
553    }
554
555    fn close(
556        self,
557        common: &mut NetlinkSocketCommon,
558        cb_queue: &mut CallbackQueue,
559    ) -> (ProtocolState, Result<(), SyscallError>) {
560        // inform the buffer that there is one fewer readers
561        common
562            .buffer
563            .borrow_mut()
564            .remove_reader(self.reader_handle, cb_queue);
565
566        let new_state = ClosedState {};
567        new_state.refresh_file_state(common, FileSignals::empty(), cb_queue);
568        (new_state.into(), Ok(()))
569    }
570
571    fn bind(
572        &mut self,
573        _common: &mut NetlinkSocketCommon,
574        _socket: &Arc<AtomicRefCell<NetlinkSocket>>,
575        addr: Option<&SockaddrStorage>,
576        _rng: impl rand::Rng,
577    ) -> Result<(), SyscallError> {
578        // if already bound
579        if self.bound_addr.is_some() {
580            return Err(Errno::EINVAL.into());
581        }
582        // if the bound address is null
583        if addr.is_none() {
584            return Err(Errno::EFAULT.into());
585        }
586
587        // get the netlink address
588        let Some(addr) = addr.and_then(|x| x.as_netlink()) else {
589            log::warn!("Attempted to bind netlink socket to non-netlink address {addr:?}");
590            return Err(Errno::EINVAL.into());
591        };
592
593        // TODO: According to netlink(7), if the pid is zero, the kernel takes care of assigning
594        // it, but we will leave it untouched at the moment. We can implement the assignment
595        // later when we want to support it.
596        self.bound_addr = Some(*addr);
597
598        // According to netlink(7), if the groups is non-zero, it means that the socket wants to
599        // listen to some groups. If it includes unsupported groups, we will emit the error here.
600        if (addr.groups() & !(RTMGRP_IPV4_IFADDR | RTMGRP_IPV6_IFADDR)) != 0 {
601            log::warn!(
602                "Attempted to bind netlink socket to an address with unsupported groups {}",
603                addr.groups()
604            );
605            return Err(Errno::EINVAL.into());
606        }
607
608        Ok(())
609    }
610
611    fn sendmsg(
612        &mut self,
613        common: &mut NetlinkSocketCommon,
614        socket: &Arc<AtomicRefCell<NetlinkSocket>>,
615        args: SendmsgArgs,
616        mem: &mut MemoryManager,
617        cb_queue: &mut CallbackQueue,
618    ) -> Result<libc::ssize_t, SyscallError> {
619        if !args.control_ptr.ptr().is_null() {
620            log::debug!("Netlink sockets don't yet support control data for sendmsg()");
621            return Err(Errno::EINVAL.into());
622        }
623
624        // It's okay to not have a destination address
625        if let Some(addr) = args.addr {
626            // Parse the address
627            let Some(addr) = addr.as_netlink() else {
628                log::warn!("Attempted to send to non-netlink address {:?}", args.addr);
629                return Err(Errno::EINVAL.into());
630            };
631            // Sending to non-kernel address is not supported
632            if addr.pid() != 0 {
633                log::warn!("Attempted to send to non-kernel netlink address {addr:?}");
634                return Err(Errno::EINVAL.into());
635            }
636            // Sending to groups is not supported
637            if addr.groups() != 0 {
638                log::warn!("Attempted to send to netlink groups {addr:?}");
639                return Err(Errno::EINVAL.into());
640            }
641        }
642
643        let rv = common.sendmsg(socket, args.iovs, args.flags, mem, cb_queue)?;
644
645        self.refresh_file_state(common, FileSignals::empty(), cb_queue);
646
647        Ok(rv.try_into().unwrap())
648    }
649
650    fn recvmsg(
651        &mut self,
652        common: &mut NetlinkSocketCommon,
653        socket: &Arc<AtomicRefCell<NetlinkSocket>>,
654        args: RecvmsgArgs,
655        mem: &mut MemoryManager,
656        cb_queue: &mut CallbackQueue,
657    ) -> Result<RecvmsgReturn, SyscallError> {
658        if !args.control_ptr.ptr().is_null() {
659            log::debug!("Netlink sockets don't yet support control data for recvmsg()");
660            return Err(Errno::EINVAL.into());
661        }
662        let Some(flags) = MsgFlags::from_bits(args.flags) else {
663            warn_once_then_debug!("Unrecognized recv flags: {:#b}", args.flags);
664            return Err(Errno::EINVAL.into());
665        };
666
667        let mut packet_buffer = Vec::new();
668        let (_rv, _num_removed_from_buf) =
669            common.recvmsg(socket, &mut packet_buffer, flags, mem, cb_queue)?;
670        self.refresh_file_state(common, FileSignals::empty(), cb_queue);
671
672        let mut writer = IoVecWriter::new(args.iovs, mem);
673
674        // We set the source address as the netlink address of the kernel
675        let src_addr = SockaddrStorage::from_netlink(&NetlinkAddr::new(0, 0));
676
677        if packet_buffer.len() < std::mem::size_of::<nlmsghdr>() {
678            log::warn!("The processed packet is too short");
679            return Err(Errno::EINVAL.into());
680        }
681
682        let buffer = {
683            let nlmsg_type = &packet_buffer[memoffset::span_of!(nlmsghdr, nlmsg_type)];
684            let nlmsg_type = u16::from_ne_bytes(nlmsg_type.try_into().unwrap());
685
686            match nlmsg_type {
687                RTM_GETLINK => {
688                    let nlmsghdr_len = std::mem::size_of::<nlmsghdr>();
689                    let ifinfomsg_len = std::mem::size_of::<ifinfomsg>();
690                    let header_len = nlmsghdr_len + ifinfomsg_len;
691
692                    // Pad the message if it's too short and update the len field
693                    //
694                    // We typically try not to zero-fill structs when the bytes are missing,
695                    // but it should be okay here since we don't yet support most of the fields
696                    // of ifinfomsg
697                    if (nlmsghdr_len..header_len).contains(&packet_buffer.len()) {
698                        log::debug!(
699                            "Padding the RTM_GETLINK with zeroes to meet the minimum length"
700                        );
701                        packet_buffer.resize(header_len, 0);
702                        packet_buffer[memoffset::span_of!(nlmsghdr, nlmsg_len)]
703                            .copy_from_slice(&(header_len as u32).to_ne_bytes()[..]);
704                    }
705                    self.handle_ifinfomsg(common, &packet_buffer[..])
706                }
707                RTM_GETADDR => {
708                    let nlmsghdr_len = std::mem::size_of::<nlmsghdr>();
709                    let ifaddrmsg_len = std::mem::size_of::<ifaddrmsg>();
710                    let header_len = nlmsghdr_len + ifaddrmsg_len;
711
712                    // Pad the message if it's too short and update the len field
713                    //
714                    // We typically try not to zero-fill structs when the bytes are missing,
715                    // but it should be okay here since we don't yet support most of the fields
716                    // of ifaddrmsg
717                    if (nlmsghdr_len..header_len).contains(&packet_buffer.len()) {
718                        log::debug!(
719                            "Padding the RTM_GETADDR with zeroes to meet the minimum length"
720                        );
721                        packet_buffer.resize(header_len, 0);
722                        packet_buffer[memoffset::span_of!(nlmsghdr, nlmsg_len)]
723                            .copy_from_slice(&(header_len as u32).to_ne_bytes()[..]);
724                    }
725                    self.handle_ifaddrmsg(common, &packet_buffer[..])
726                }
727                _ => {
728                    warn_once_then_debug!(
729                        "Found unsupported nlmsg_type: {nlmsg_type} (only RTM_GETLINK
730                        and RTM_GETADDR are supported)"
731                    );
732                    self.handle_error(&packet_buffer[..])
733                }
734            }
735        };
736
737        // Try to write as much as we can. If the buffer is too small, just discard the rest
738        let mut total_copied = 0;
739        let mut buf = buffer.as_slice();
740        while !buf.is_empty() {
741            match writer.write(buf) {
742                Ok(0) => break,
743                Ok(n) => {
744                    buf = &buf[n..];
745                    total_copied += n;
746                }
747                Err(ref e) if e.kind() == ErrorKind::Interrupted => continue,
748                Err(e) => return Err(e.into()),
749            }
750        }
751
752        let return_val = if flags.contains(MsgFlags::MSG_TRUNC) {
753            buffer.len()
754        } else {
755            total_copied
756        };
757
758        Ok(RecvmsgReturn {
759            return_val: return_val.try_into().unwrap(),
760            addr: Some(src_addr),
761            msg_flags: 0,
762            control_len: 0,
763        })
764    }
765
766    fn handle_error(&self, bytes: &[u8]) -> Vec<u8> {
767        // If we can't get the pid, set it to zero
768        let nlmsg_seq = match bytes.get(memoffset::span_of!(nlmsghdr, nlmsg_seq)) {
769            Some(x) => u32::from_ne_bytes(x.try_into().unwrap()),
770            None => 0,
771        };
772
773        // Generate a dummy error with the same sequence number as the request
774        let msg = NlmsghdrBuilder::default()
775            .nl_type(Nlmsg::Error)
776            .nl_flags(NlmF::empty())
777            .nl_seq(nlmsg_seq)
778            .nl_payload(NlPayload::<Nlmsg, ()>::Empty)
779            .build()
780            .expect("NlmsghdrBuilder missing a required field");
781
782        let mut buffer = Cursor::new(Vec::new());
783        msg.to_bytes(&mut buffer).unwrap();
784        buffer.into_inner()
785    }
786
787    fn handle_ifaddrmsg(&self, common: &mut NetlinkSocketCommon, bytes: &[u8]) -> Vec<u8> {
788        let Ok(nlmsg) = Nlmsghdr::<Rtm, Ifaddrmsg>::from_bytes(&mut Cursor::new(bytes)) else {
789            log::warn!("Failed to deserialize the message");
790            return self.handle_error(bytes);
791        };
792
793        let Some(ifaddrmsg) = nlmsg.get_payload() else {
794            log::warn!("Failed to find the payload");
795            return self.handle_error(bytes);
796        };
797
798        // The only supported interface address family is AF_INET
799        if *ifaddrmsg.ifa_family() != RtAddrFamily::Unspecified
800            && *ifaddrmsg.ifa_family() != RtAddrFamily::Inet
801        {
802            log::warn!("Unsupported ifa_family (only AF_UNSPEC and AF_INET are supported)");
803            return self.handle_error(bytes);
804        }
805
806        // The rest of the fields are unsupported. We limit only the interest to the zero values
807        if *ifaddrmsg.ifa_prefixlen() != 0
808            || !ifaddrmsg.ifa_flags().is_empty()
809            || *ifaddrmsg.ifa_index() != 0
810            || *ifaddrmsg.ifa_scope() != RtScope::Universe
811        {
812            log::warn!(
813                "Unsupported ifa_prefixlen, ifa_flags, ifa_scope, or ifa_index (they have to be 0)",
814            );
815            return self.handle_error(bytes);
816        }
817
818        let mut buffer = Cursor::new(Vec::new());
819        // Send the interface addresses
820        for interface in &common.interfaces {
821            let address = interface.address.octets();
822            let broadcast = Ipv4Addr::from(
823                0xffff_ffff_u32
824                    .checked_shr(u32::from(interface.prefix_len))
825                    .unwrap_or(0)
826                    | u32::from(interface.address),
827            )
828            .octets();
829            let mut label = Vec::from(interface.label.as_bytes());
830            label.push(0); // Null-terminate
831
832            // List of attribtes sent with the response for the current interface
833            let attrs = [
834                // I don't know the difference between IFA_ADDRESS and IFA_LOCAL. However, Linux
835                // provides the same address for both attributes, so I do the same.
836                // Run `strace ip addr` to see.
837                RtattrBuilder::default()
838                    .rta_type(Ifa::Address)
839                    .rta_payload(Buffer::from(&address[..]))
840                    .build()
841                    .unwrap(),
842                RtattrBuilder::default()
843                    .rta_type(Ifa::Local)
844                    .rta_payload(Buffer::from(&address[..]))
845                    .build()
846                    .unwrap(),
847                RtattrBuilder::default()
848                    .rta_type(Ifa::Broadcast)
849                    .rta_payload(Buffer::from(&broadcast[..]))
850                    .build()
851                    .unwrap(),
852                RtattrBuilder::default()
853                    .rta_type(Ifa::Label)
854                    .rta_payload(Buffer::from(label))
855                    .build()
856                    .unwrap(),
857            ];
858            let ifaddrmsg = IfaddrmsgBuilder::default()
859                .ifa_family(RtAddrFamily::Inet)
860                .ifa_prefixlen(interface.prefix_len)
861                // IFA_F_PERMANENT is used to indicate that the address is permanent
862                .ifa_flags(IfaF::PERMANENT)
863                .ifa_scope(interface.scope)
864                .ifa_index(interface.index)
865                .rtattrs(RtBuffer::from_iter(attrs))
866                .build()
867                .expect("IfaddrmsgBuilder missing a required field");
868            let nlmsg = NlmsghdrBuilder::default()
869                .nl_type(Rtm::Newaddr)
870                // The NLM_F_MULTI flag is used to indicate that we will send multiple messages
871                .nl_flags(NlmF::MULTI)
872                // Use the same sequence number as the request
873                .nl_seq(*nlmsg.nl_seq())
874                .nl_payload(NlPayload::Payload(ifaddrmsg))
875                .build()
876                .expect("NlmsghdrBuilder missing a required field");
877            nlmsg.to_bytes(&mut buffer).unwrap();
878        }
879        // After sending the messages with the NLM_F_MULTI flag set, we need to send the NLMSG_DONE message
880        let done_msg = NlmsghdrBuilder::default()
881            .nl_type(Nlmsg::Done)
882            .nl_flags(NlmF::MULTI)
883            // Use the same sequence number as the request
884            .nl_seq(*nlmsg.nl_seq())
885            // Linux also emits the errno of zero after the header. See `strace ip addr`.
886            // For documentation reference, see:
887            // https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/Documentation/userspace-api/netlink/intro.rst?h=v6.2#n232
888            // For code reference, see:
889            // https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/netlink/af_netlink.c?h=v6.2#n2222
890            .nl_payload(NlPayload::Payload(0u32))
891            .build()
892            .expect("NlmsghdrBuilder missing a required field");
893        done_msg.to_bytes(&mut buffer).unwrap();
894
895        buffer.into_inner()
896    }
897
898    fn handle_ifinfomsg(&self, common: &mut NetlinkSocketCommon, bytes: &[u8]) -> Vec<u8> {
899        let Ok(nlmsg) = Nlmsghdr::<Rtm, Ifinfomsg>::from_bytes(&mut Cursor::new(bytes)) else {
900            log::warn!("Failed to deserialize the message");
901            return self.handle_error(bytes);
902        };
903
904        let Some(ifinfomsg) = nlmsg.get_payload() else {
905            log::warn!("Failed to find the payload");
906            return self.handle_error(bytes);
907        };
908
909        // The only supported interface address family is AF_INET
910        if *ifinfomsg.ifi_family() != RtAddrFamily::Unspecified
911            && *ifinfomsg.ifi_family() != RtAddrFamily::Inet
912        {
913            warn_once_then_debug!(
914                "Unsupported ifi_family (only AF_UNSPEC and AF_INET are supported)"
915            );
916            return self.handle_error(bytes);
917        }
918
919        // The rest of the fields are unsupported. We limit only the interest to the zero values
920        if *ifinfomsg.ifi_type() != 0.into()
921            || *ifinfomsg.ifi_index() != 0
922            || !ifinfomsg.ifi_flags().is_empty()
923        {
924            warn_once_then_debug!(
925                "Unsupported ifi_type, ifi_index, or ifi_flags (they have to be 0)"
926            );
927            return self.handle_error(bytes);
928        }
929
930        // We don't check for ifi_change because we found that `ip addr` sets it to zero even if
931        // rtnetlink(7) recommends to set it to all 1s
932
933        let mut buffer = Cursor::new(Vec::new());
934        // Send the interface addresses
935        for interface in &common.interfaces {
936            let mut label = Vec::from(interface.label.as_bytes());
937            label.push(0); // Null-terminate
938
939            // List of attribtes sent with the response for the current interface
940            let attrs = [
941                RtattrBuilder::default()
942                    .rta_type(Ifla::Ifname)
943                    .rta_payload(Buffer::from(label))
944                    .build()
945                    .unwrap(),
946                // Not sure about the value of this one, but I always see 1000 from `ip addr`. If
947                // we don't specify this, `ip addr` will create an AF_INET socket and do ioctl. See
948                // https://git.kernel.org/pub/scm/network/iproute2/iproute2.git/tree/ip/ipaddress.c#n168
949                RtattrBuilder::default()
950                    .rta_type(Ifla::Txqlen)
951                    .rta_payload(Buffer::from(&u32::to_le_bytes(1000)[..]))
952                    .build()
953                    .unwrap(),
954                RtattrBuilder::default()
955                    .rta_type(Ifla::Mtu)
956                    .rta_payload(Buffer::from(&u32::to_le_bytes(interface.mtu)[..]))
957                    .build()
958                    .unwrap(),
959                // TODO: Add the MAC address through IFLA_ADDRESS and IFLA_BROADCAST
960            ];
961            let flags = if interface.if_type == Arphrd::Loopback {
962                Iff::UP | Iff::LOOPBACK | Iff::RUNNING
963            } else {
964                // Not sure about the IFF_MULTICAST, but it's also the one I got from `strace ip addr`
965                Iff::UP | Iff::BROADCAST | Iff::RUNNING | Iff::MULTICAST
966            };
967            let interface_index = interface
968                .index
969                .try_into()
970                .expect("interface index too large");
971
972            let ifinfomsg = IfinfomsgBuilder::default()
973                .ifi_family(RtAddrFamily::Inet)
974                .ifi_type(interface.if_type)
975                .ifi_index(interface_index)
976                .ifi_flags(flags)
977                // rtnetlink(7) recommends to set it to all 1s
978                .ifi_change(Iff::from_bits_retain(0xffffffff))
979                .rtattrs(RtBuffer::from_iter(attrs))
980                .build()
981                .expect("IfinfomsgBuilder missing a required field");
982            let nlmsg = NlmsghdrBuilder::default()
983                .nl_type(Rtm::Newlink)
984                // The NLM_F_MULTI flag is used to indicate that we will send multiple messages
985                .nl_flags(NlmF::MULTI)
986                // Use the same sequence number as the request
987                .nl_seq(*nlmsg.nl_seq())
988                .nl_payload(NlPayload::Payload(ifinfomsg))
989                .build()
990                .expect("NlmsghdrBuilder missing a required field");
991            nlmsg.to_bytes(&mut buffer).unwrap();
992        }
993        // After sending the messages with the NLM_F_MULTI flag set, we need to send the NLMSG_DONE message
994        let done_msg = NlmsghdrBuilder::default()
995            .nl_type(Nlmsg::Done)
996            .nl_flags(NlmF::MULTI)
997            // Use the same sequence number as the request
998            .nl_seq(*nlmsg.nl_seq())
999            // Linux also emits the errno of zero after the header. See `strace ip addr`.
1000            // For documentation reference, see:
1001            // https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/Documentation/userspace-api/netlink/intro.rst?h=v6.2#n232
1002            // For code reference, see:
1003            // https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/netlink/af_netlink.c?h=v6.2#n2222
1004            .nl_payload(NlPayload::Payload(0u32))
1005            .build()
1006            .expect("NlmsghdrBuilder missing a required field");
1007        done_msg.to_bytes(&mut buffer).unwrap();
1008
1009        buffer.into_inner()
1010    }
1011}
1012
1013impl ClosedState {
1014    fn bound_address(&self) -> Result<Option<NetlinkAddr>, Errno> {
1015        Ok(None)
1016    }
1017
1018    fn refresh_file_state(
1019        &self,
1020        common: &mut NetlinkSocketCommon,
1021        signals: FileSignals,
1022        cb_queue: &mut CallbackQueue,
1023    ) {
1024        common.update_state(
1025            /* mask= */ FileState::all(),
1026            FileState::CLOSED,
1027            signals,
1028            cb_queue,
1029        );
1030    }
1031
1032    fn close(
1033        self,
1034        _common: &mut NetlinkSocketCommon,
1035        _cb_queue: &mut CallbackQueue,
1036    ) -> (ProtocolState, Result<(), SyscallError>) {
1037        // why are we trying to close an already closed file? we probably want a bt here...
1038        panic!("Trying to close an already closed socket");
1039    }
1040
1041    fn bind(
1042        &mut self,
1043        _common: &mut NetlinkSocketCommon,
1044        _socket: &Arc<AtomicRefCell<NetlinkSocket>>,
1045        _addr: Option<&SockaddrStorage>,
1046        _rng: impl rand::Rng,
1047    ) -> Result<(), SyscallError> {
1048        // We follow the same approach as UnixSocket
1049        log::warn!("bind() while in state {}", std::any::type_name::<Self>());
1050        Err(Errno::EOPNOTSUPP.into())
1051    }
1052
1053    fn sendmsg(
1054        &mut self,
1055        _common: &mut NetlinkSocketCommon,
1056        _socket: &Arc<AtomicRefCell<NetlinkSocket>>,
1057        _args: SendmsgArgs,
1058        _mem: &mut MemoryManager,
1059        _cb_queue: &mut CallbackQueue,
1060    ) -> Result<libc::ssize_t, SyscallError> {
1061        // We follow the same approach as UnixSocket
1062        log::warn!("sendmsg() while in state {}", std::any::type_name::<Self>());
1063        Err(Errno::EOPNOTSUPP.into())
1064    }
1065
1066    fn recvmsg(
1067        &mut self,
1068        _common: &mut NetlinkSocketCommon,
1069        _socket: &Arc<AtomicRefCell<NetlinkSocket>>,
1070        _args: RecvmsgArgs,
1071        _mem: &mut MemoryManager,
1072        _cb_queue: &mut CallbackQueue,
1073    ) -> Result<RecvmsgReturn, SyscallError> {
1074        // We follow the same approach as UnixSocket
1075        log::warn!("recvmsg() while in state {}", std::any::type_name::<Self>());
1076        Err(Errno::EOPNOTSUPP.into())
1077    }
1078}
1079
1080// The struct used to describe the network interface
1081struct Interface {
1082    address: Ipv4Addr,
1083    label: String,
1084    prefix_len: u8,
1085    if_type: Arphrd,
1086    mtu: u32,
1087    scope: RtScope,
1088    index: libc::c_uint,
1089}
1090
1091/// Common data and functionality that is useful for all states.
1092struct NetlinkSocketCommon {
1093    buffer: Arc<AtomicRefCell<SharedBuf>>,
1094    /// The max number of "in flight" bytes (sent but not yet read from the receiving socket).
1095    send_limit: u64,
1096    /// The number of "in flight" bytes.
1097    sent_len: u64,
1098    event_source: StateEventSource,
1099    state: FileState,
1100    status: FileStatus,
1101    // should only be used by `OpenFile` to make sure there is only ever one `OpenFile` instance for
1102    // this file
1103    has_open_file: bool,
1104    /// Interfaces
1105    interfaces: Vec<Interface>,
1106}
1107
1108impl NetlinkSocketCommon {
1109    pub fn supports_sa_restart(&self) -> bool {
1110        true
1111    }
1112
1113    pub fn sendmsg(
1114        &mut self,
1115        socket: &Arc<AtomicRefCell<NetlinkSocket>>,
1116        iovs: &[IoVec],
1117        flags: libc::c_int,
1118        mem: &mut MemoryManager,
1119        cb_queue: &mut CallbackQueue,
1120    ) -> Result<usize, SyscallError> {
1121        // MSG_NOSIGNAL is a no-op, since netlink sockets are not stream-oriented.
1122        // Ignore the MSG_TRUNC flag since it doesn't do anything when sending.
1123        let supported_flags = MsgFlags::MSG_DONTWAIT | MsgFlags::MSG_NOSIGNAL | MsgFlags::MSG_TRUNC;
1124
1125        // if there's a flag we don't support, it's probably best to raise an error rather than do
1126        // the wrong thing
1127        let Some(mut flags) = MsgFlags::from_bits(flags) else {
1128            warn_once_then_debug!("Unrecognized send flags: {:#b}", flags);
1129            return Err(Errno::EINVAL.into());
1130        };
1131        if flags.intersects(!supported_flags) {
1132            warn_once_then_debug!("Unsupported send flags: {:?}", flags);
1133            return Err(Errno::EINVAL.into());
1134        }
1135
1136        if self.status.contains(FileStatus::NONBLOCK) {
1137            flags.insert(MsgFlags::MSG_DONTWAIT);
1138        }
1139
1140        // run in a closure so that an early return doesn't return from the syscall handler
1141        let result = (|| {
1142            let len = iovs.iter().map(|x| x.len).sum::<libc::size_t>();
1143
1144            // we keep track of the send buffer size manually, since the netlink socket buffers all
1145            // have usize::MAX length
1146            let space_available = self
1147                .send_limit
1148                .saturating_sub(self.sent_len)
1149                .try_into()
1150                .unwrap();
1151
1152            if space_available == 0 {
1153                return Err(Errno::EAGAIN);
1154            }
1155
1156            if len > space_available {
1157                if len <= self.send_limit.try_into().unwrap() {
1158                    // we can send this when the buffer has more space available
1159                    return Err(Errno::EAGAIN);
1160                } else {
1161                    // we could never send this message
1162                    return Err(Errno::EMSGSIZE);
1163                }
1164            }
1165
1166            let reader = IoVecReader::new(iovs, mem);
1167            let reader = reader.take(len.try_into().unwrap());
1168
1169            // send the packet directly to the buffer of the socket so that it will be
1170            // processed when the socket is read.
1171            self.buffer
1172                .borrow_mut()
1173                .write_packet(reader, len, cb_queue)
1174                .map_err(|e| Errno::try_from(e).unwrap())?;
1175
1176            // if we successfully sent bytes, update the sent count
1177            self.sent_len += u64::try_from(len).unwrap();
1178            Ok(len)
1179        })();
1180
1181        // if the syscall would block and we don't have the MSG_DONTWAIT flag
1182        if result.as_ref().err() == Some(&Errno::EWOULDBLOCK)
1183            && !flags.contains(MsgFlags::MSG_DONTWAIT)
1184        {
1185            return Err(SyscallError::new_blocked_on_file(
1186                File::Socket(Socket::Netlink(socket.clone())),
1187                FileState::WRITABLE,
1188                self.supports_sa_restart(),
1189            ));
1190        }
1191
1192        Ok(result?)
1193    }
1194
1195    pub fn recvmsg<W: Write>(
1196        &mut self,
1197        socket: &Arc<AtomicRefCell<NetlinkSocket>>,
1198        dst: W,
1199        mut flags: MsgFlags,
1200        _mem: &mut MemoryManager,
1201        cb_queue: &mut CallbackQueue,
1202    ) -> Result<(usize, usize), SyscallError> {
1203        let supported_flags = MsgFlags::MSG_DONTWAIT | MsgFlags::MSG_PEEK | MsgFlags::MSG_TRUNC;
1204
1205        // if there's a flag we don't support, it's probably best to raise an error rather than do
1206        // the wrong thing
1207        if flags.intersects(!supported_flags) {
1208            warn_once_then_debug!("Unsupported recv flags: {:?}", flags);
1209            return Err(Errno::EINVAL.into());
1210        }
1211
1212        if self.status.contains(FileStatus::NONBLOCK) {
1213            flags.insert(MsgFlags::MSG_DONTWAIT);
1214        }
1215
1216        // run in a closure so that an early return doesn't return from the syscall handler
1217        let result = (|| {
1218            let mut buffer = self.buffer.borrow_mut();
1219
1220            // the read would block if the buffer has no data
1221            if !buffer.has_data() {
1222                return Err(Errno::EWOULDBLOCK);
1223            }
1224
1225            let (num_copied, num_removed_from_buf) = if flags.contains(MsgFlags::MSG_PEEK) {
1226                buffer.peek(dst).map_err(|e| Errno::try_from(e).unwrap())?
1227            } else {
1228                buffer
1229                    .read(dst, cb_queue)
1230                    .map_err(|e| Errno::try_from(e).unwrap())?
1231            };
1232
1233            if flags.contains(MsgFlags::MSG_TRUNC) {
1234                // return the total size of the message, not the number of bytes we read
1235                Ok((num_removed_from_buf, num_removed_from_buf))
1236            } else {
1237                Ok((num_copied, num_removed_from_buf))
1238            }
1239        })();
1240
1241        // if the syscall would block and we don't have the MSG_DONTWAIT flag
1242        if result.as_ref().err() == Some(&Errno::EWOULDBLOCK)
1243            && !flags.contains(MsgFlags::MSG_DONTWAIT)
1244        {
1245            return Err(SyscallError::new_blocked_on_file(
1246                File::Socket(Socket::Netlink(socket.clone())),
1247                FileState::READABLE,
1248                self.supports_sa_restart(),
1249            ));
1250        }
1251
1252        Ok(result?)
1253    }
1254
1255    fn update_state(
1256        &mut self,
1257        mask: FileState,
1258        state: FileState,
1259        signals: FileSignals,
1260        cb_queue: &mut CallbackQueue,
1261    ) {
1262        let old_state = self.state;
1263
1264        // remove the masked flags, then copy the masked flags
1265        self.state.remove(mask);
1266        self.state.insert(state & mask);
1267
1268        self.handle_state_change(old_state, signals, cb_queue);
1269    }
1270
1271    fn handle_state_change(
1272        &mut self,
1273        old_state: FileState,
1274        signals: FileSignals,
1275        cb_queue: &mut CallbackQueue,
1276    ) {
1277        let states_changed = self.state ^ old_state;
1278
1279        // if nothing changed
1280        if states_changed.is_empty() && signals.is_empty() {
1281            return;
1282        }
1283
1284        self.event_source
1285            .notify_listeners(self.state, states_changed, signals, cb_queue);
1286    }
1287}
1288
1289#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
1290pub enum NetlinkSocketType {
1291    Dgram,
1292    Raw,
1293}
1294
1295impl TryFrom<libc::c_int> for NetlinkSocketType {
1296    type Error = NetlinkSocketTypeConversionError;
1297    fn try_from(val: libc::c_int) -> Result<Self, Self::Error> {
1298        match val {
1299            libc::SOCK_DGRAM => Ok(Self::Dgram),
1300            libc::SOCK_RAW => Ok(Self::Raw),
1301            x => Err(NetlinkSocketTypeConversionError(x)),
1302        }
1303    }
1304}
1305
1306#[derive(Copy, Clone, Debug)]
1307pub struct NetlinkSocketTypeConversionError(libc::c_int);
1308
1309impl std::error::Error for NetlinkSocketTypeConversionError {}
1310
1311impl std::fmt::Display for NetlinkSocketTypeConversionError {
1312    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
1313        write!(
1314            f,
1315            "Invalid socket type {}; netlink sockets only support SOCK_DGRAM and SOCK_RAW",
1316            self.0
1317        )
1318    }
1319}
1320
1321#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
1322pub enum NetlinkFamily {
1323    Route,
1324}
1325
1326impl TryFrom<libc::c_int> for NetlinkFamily {
1327    type Error = NetlinkFamilyConversionError;
1328    fn try_from(val: libc::c_int) -> Result<Self, Self::Error> {
1329        match val {
1330            libc::NETLINK_ROUTE => Ok(Self::Route),
1331            x => Err(NetlinkFamilyConversionError(x)),
1332        }
1333    }
1334}
1335
1336#[derive(Copy, Clone, Debug)]
1337pub struct NetlinkFamilyConversionError(libc::c_int);
1338
1339impl std::error::Error for NetlinkFamilyConversionError {}
1340
1341impl std::fmt::Display for NetlinkFamilyConversionError {
1342    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
1343        write!(
1344            f,
1345            "Invalid netlink family {}; netlink families only support NETLINK_ROUTE",
1346            self.0
1347        )
1348    }
1349}