1use std::io::{Cursor, ErrorKind, Read, Write};
2use std::net::Ipv4Addr;
3use std::sync::{Arc, Weak};
4
5use atomic_refcell::AtomicRefCell;
6use linux_api::errno::Errno;
7use linux_api::ioctls::IoctlRequest;
8use linux_api::netlink::{ifaddrmsg, ifinfomsg, nlmsghdr};
9use linux_api::rtnetlink::{RTM_GETADDR, RTM_GETLINK, RTMGRP_IPV4_IFADDR, RTMGRP_IPV6_IFADDR};
10use linux_api::socket::Shutdown;
11use neli::consts::nl::{NlmF, Nlmsg};
12use neli::consts::rtnl::{Arphrd, Ifa, IfaF, Iff, Ifla, RtAddrFamily, RtScope, Rtm};
13use neli::nl::{NlPayload, Nlmsghdr, NlmsghdrBuilder};
14use neli::rtnl::{Ifaddrmsg, IfaddrmsgBuilder, Ifinfomsg, IfinfomsgBuilder, RtattrBuilder};
15use neli::types::{Buffer, RtBuffer};
16use neli::{FromBytes, ToBytes};
17use nix::sys::socket::{MsgFlags, NetlinkAddr};
18use shadow_shim_helper_rs::syscall_types::ForeignPtr;
19
20use crate::core::worker::Worker;
21use crate::cshadow as c;
22use crate::host::descriptor::listener::{StateEventSource, StateListenHandle, StateListenerFilter};
23use crate::host::descriptor::shared_buf::{
24    BufferHandle, BufferSignals, BufferState, ReaderHandle, SharedBuf,
25};
26use crate::host::descriptor::socket::{RecvmsgArgs, RecvmsgReturn, SendmsgArgs, Socket};
27use crate::host::descriptor::{
28    File, FileMode, FileSignals, FileState, FileStatus, OpenFile, SyscallResult,
29};
30use crate::host::memory_manager::MemoryManager;
31use crate::host::network::namespace::NetworkNamespace;
32use crate::host::syscall::io::{IoVec, IoVecReader, IoVecWriter};
33use crate::host::syscall::types::SyscallError;
34use crate::utility::HostTreePointer;
35use crate::utility::callback_queue::CallbackQueue;
36use crate::utility::sockaddr::SockaddrStorage;
37
/// Initial value of a new socket's `send_limit` (assigned in `NetlinkSocket::new`).
// NOTE(review): presumably a byte count mirroring Linux's default socket buffer size — confirm.
const NETLINK_SOCKET_DEFAULT_BUFFER_SIZE: u64 = 212_992;
40
/// A netlink socket, split into protocol-independent data (`common`) and a protocol state
/// machine (`protocol_state`) to which most socket operations are delegated.
pub struct NetlinkSocket {
    /// Shared buffer, send limit/usage, file state and status, listeners, and the interface list.
    common: NetlinkSocketCommon,
    /// Current protocol state (`Initial` or `Closed`).
    protocol_state: ProtocolState,
}
47
48impl NetlinkSocket {
49    pub fn new(
50        status: FileStatus,
51        _socket_type: NetlinkSocketType,
52        _family: NetlinkFamily,
53    ) -> Arc<AtomicRefCell<Self>> {
54        Arc::new_cyclic(|weak| {
55            let buffer = SharedBuf::new(usize::MAX);
57            let buffer = Arc::new(AtomicRefCell::new(buffer));
58
59            let default_ip = Worker::with_active_host(|host| host.default_ip()).unwrap();
61            let interfaces = vec![
63                Interface {
64                    address: Ipv4Addr::LOCALHOST,
65                    label: String::from("lo"),
66                    prefix_len: 8,
67                    if_type: Arphrd::Loopback,
68                    mtu: c::CONFIG_MTU,
69                    scope: RtScope::Host,
70                    index: 1,
71                },
72                Interface {
73                    address: default_ip,
74                    label: String::from("eth0"),
75                    prefix_len: 24,
76                    if_type: Arphrd::Ether,
77                    mtu: c::CONFIG_MTU,
78                    scope: RtScope::Universe,
79                    index: 2,
80                },
81            ];
82
83            let mut common = NetlinkSocketCommon {
84                buffer,
85                send_limit: NETLINK_SOCKET_DEFAULT_BUFFER_SIZE,
86                sent_len: 0,
87                event_source: StateEventSource::new(),
88                state: FileState::ACTIVE,
89                status,
90                has_open_file: false,
91                interfaces,
92            };
93            let protocol_state = ProtocolState::new(&mut common, weak);
94            let mut socket = Self {
95                common,
96                protocol_state,
97            };
98            CallbackQueue::queue_and_run_with_legacy(|cb_queue| {
99                socket.refresh_file_state(FileSignals::empty(), cb_queue)
100            });
101
102            AtomicRefCell::new(socket)
103        })
104    }
105
106    pub fn status(&self) -> FileStatus {
107        self.common.status
108    }
109
110    pub fn set_status(&mut self, status: FileStatus) {
111        self.common.status = status;
112    }
113
114    pub fn mode(&self) -> FileMode {
115        FileMode::READ | FileMode::WRITE
116    }
117
118    pub fn has_open_file(&self) -> bool {
119        self.common.has_open_file
120    }
121
122    pub fn supports_sa_restart(&self) -> bool {
123        self.common.supports_sa_restart()
124    }
125
126    pub fn set_has_open_file(&mut self, val: bool) {
127        self.common.has_open_file = val;
128    }
129
130    pub fn getsockname(&self) -> Result<Option<nix::sys::socket::NetlinkAddr>, Errno> {
131        self.protocol_state.bound_address()
132    }
133
    /// Not implemented for netlink sockets: logs a warning (once, then at debug level) and
    /// always returns `ENOSYS`.
    pub fn getpeername(&self) -> Result<Option<nix::sys::socket::NetlinkAddr>, Errno> {
        warn_once_then_debug!(
            "getpeername() syscall not yet supported for netlink sockets; Returning ENOSYS"
        );
        Err(Errno::ENOSYS)
    }
140
141    pub fn address_family(&self) -> linux_api::socket::AddressFamily {
142        linux_api::socket::AddressFamily::AF_NETLINK
143    }
144
145    pub fn close(&mut self, cb_queue: &mut CallbackQueue) -> Result<(), SyscallError> {
146        self.protocol_state.close(&mut self.common, cb_queue)
147    }
148
149    fn refresh_file_state(&mut self, signals: FileSignals, cb_queue: &mut CallbackQueue) {
150        self.protocol_state
151            .refresh_file_state(&mut self.common, signals, cb_queue)
152    }
153
154    pub fn shutdown(
155        &mut self,
156        _how: Shutdown,
157        _cb_queue: &mut CallbackQueue,
158    ) -> Result<(), SyscallError> {
159        warn_once_then_debug!(
160            "shutdown() syscall not yet supported for netlink sockets; Returning ENOSYS"
161        );
162        Err(Errno::ENOSYS.into())
163    }
164
165    pub fn getsockopt(
166        &mut self,
167        _level: libc::c_int,
168        _optname: libc::c_int,
169        _optval_ptr: ForeignPtr<()>,
170        _optlen: libc::socklen_t,
171        _memory_manager: &mut MemoryManager,
172        _cb_queue: &mut CallbackQueue,
173    ) -> Result<libc::socklen_t, SyscallError> {
174        warn_once_then_debug!(
175            "getsockopt() syscall not yet supported for netlink sockets; Returning ENOSYS"
176        );
177        Err(Errno::ENOSYS.into())
178    }
179
180    pub fn setsockopt(
181        &mut self,
182        level: libc::c_int,
183        optname: libc::c_int,
184        optval_ptr: ForeignPtr<()>,
185        optlen: libc::socklen_t,
186        memory_manager: &MemoryManager,
187    ) -> Result<(), SyscallError> {
188        match (level, optname) {
189            (libc::SOL_SOCKET, libc::SO_SNDBUF) => {
190                type OptType = libc::c_int;
191
192                if usize::try_from(optlen).unwrap() < std::mem::size_of::<OptType>() {
193                    return Err(Errno::EINVAL.into());
194                }
195
196                let optval_ptr = optval_ptr.cast::<OptType>();
197                let val: u64 = memory_manager
198                    .read(optval_ptr)?
199                    .try_into()
200                    .or(Err(Errno::EINVAL))?;
201
202                let val = val * 2;
204                let val = std::cmp::max(val, self.common.sent_len);
206                let val = std::cmp::max(val, 4096);
208                let val = std::cmp::min(val, 268435456); self.common.send_limit = val;
211            }
212            (libc::SOL_SOCKET, libc::SO_RCVBUF) => {
213                }
217            _ => {
218                warn_once_then_debug!(
219                    "setsockopt called with unsupported level {level} and opt {optname}"
220                );
221                return Err(Errno::ENOPROTOOPT.into());
222            }
223        }
224
225        Ok(())
226    }
227
228    pub fn bind(
229        socket: &Arc<AtomicRefCell<Self>>,
230        addr: Option<&SockaddrStorage>,
231        _net_ns: &NetworkNamespace,
232        rng: impl rand::Rng,
233    ) -> Result<(), SyscallError> {
234        let socket_ref = &mut *socket.borrow_mut();
235        socket_ref
236            .protocol_state
237            .bind(&mut socket_ref.common, socket, addr, rng)
238    }
239
    /// Must never be called on a netlink socket.
    ///
    /// # Panics
    /// Always panics; every argument is ignored.
    // NOTE(review): presumably callers route socket reads through `recvmsg()` instead — confirm.
    pub fn readv(
        &mut self,
        _iovs: &[IoVec],
        _offset: Option<libc::off_t>,
        _flags: libc::c_int,
        _mem: &mut MemoryManager,
        _cb_queue: &mut CallbackQueue,
    ) -> Result<libc::ssize_t, SyscallError> {
        panic!("Called NetlinkSocket::readv() on a netlink socket.");
    }
253
    /// Must never be called on a netlink socket.
    ///
    /// # Panics
    /// Always panics; every argument is ignored.
    // NOTE(review): presumably callers route socket writes through `sendmsg()` instead — confirm.
    pub fn writev(
        &mut self,
        _iovs: &[IoVec],
        _offset: Option<libc::off_t>,
        _flags: libc::c_int,
        _mem: &mut MemoryManager,
        _cb_queue: &mut CallbackQueue,
    ) -> Result<libc::ssize_t, SyscallError> {
        panic!("Called NetlinkSocket::writev() on a netlink socket");
    }
267
268    pub fn sendmsg(
269        socket: &Arc<AtomicRefCell<Self>>,
270        args: SendmsgArgs,
271        mem: &mut MemoryManager,
272        _net_ns: &NetworkNamespace,
273        _rng: impl rand::Rng,
274        cb_queue: &mut CallbackQueue,
275    ) -> Result<libc::ssize_t, SyscallError> {
276        let socket_ref = &mut *socket.borrow_mut();
277        socket_ref
278            .protocol_state
279            .sendmsg(&mut socket_ref.common, socket, args, mem, cb_queue)
280    }
281
282    pub fn recvmsg(
283        socket: &Arc<AtomicRefCell<Self>>,
284        args: RecvmsgArgs,
285        mem: &mut MemoryManager,
286        cb_queue: &mut CallbackQueue,
287    ) -> Result<RecvmsgReturn, SyscallError> {
288        let socket_ref = &mut *socket.borrow_mut();
289        socket_ref
290            .protocol_state
291            .recvmsg(&mut socket_ref.common, socket, args, mem, cb_queue)
292    }
293
    /// Not supported on netlink sockets: logs a warning (once, then at debug level) and always
    /// returns `EINVAL`. All arguments are ignored.
    pub fn listen(
        _socket: &Arc<AtomicRefCell<Self>>,
        _backlog: i32,
        _net_ns: &NetworkNamespace,
        _rng: impl rand::Rng,
        _cb_queue: &mut CallbackQueue,
    ) -> Result<(), Errno> {
        warn_once_then_debug!("We do not yet handle listen request on netlink sockets");
        Err(Errno::EINVAL)
    }
304
305    pub fn connect(
306        _socket: &Arc<AtomicRefCell<Self>>,
307        _addr: &SockaddrStorage,
308        _net_ns: &NetworkNamespace,
309        _rng: impl rand::Rng,
310        _cb_queue: &mut CallbackQueue,
311    ) -> Result<(), SyscallError> {
312        warn_once_then_debug!("We do not yet handle connect request on netlink sockets");
313        Err(Errno::EINVAL.into())
314    }
315
316    pub fn accept(
317        &mut self,
318        _net_ns: &NetworkNamespace,
319        _rng: impl rand::Rng,
320        _cb_queue: &mut CallbackQueue,
321    ) -> Result<OpenFile, SyscallError> {
322        warn_once_then_debug!("We do not yet handle accept request on netlink sockets");
323        Err(Errno::EINVAL.into())
324    }
325
    /// Not supported on netlink sockets: logs the rejected `request` (once as a warning, then at
    /// debug level) and always fails with `EINVAL`. The remaining arguments are ignored.
    pub fn ioctl(
        &mut self,
        request: IoctlRequest,
        _arg_ptr: ForeignPtr<()>,
        _memory_manager: &mut MemoryManager,
    ) -> SyscallResult {
        warn_once_then_debug!("We do not yet handle ioctl request {request:?} on netlink sockets");
        Err(Errno::EINVAL.into())
    }
335
336    pub fn stat(&self) -> Result<linux_api::stat::stat, SyscallError> {
337        warn_once_then_debug!("We do not yet handle stat calls on netlink sockets");
338        Err(Errno::EINVAL.into())
339    }
340
341    pub fn add_listener(
342        &mut self,
343        monitoring_state: FileState,
344        monitoring_signals: FileSignals,
345        filter: StateListenerFilter,
346        notify_fn: impl Fn(FileState, FileState, FileSignals, &mut CallbackQueue)
347        + Send
348        + Sync
349        + 'static,
350    ) -> StateListenHandle {
351        self.common.event_source.add_listener(
352            monitoring_state,
353            monitoring_signals,
354            filter,
355            notify_fn,
356        )
357    }
358
359    pub fn add_legacy_listener(&mut self, ptr: HostTreePointer<c::StatusListener>) {
360        self.common.event_source.add_legacy_listener(ptr);
361    }
362
363    pub fn remove_legacy_listener(&mut self, ptr: *mut c::StatusListener) {
364        self.common.event_source.remove_legacy_listener(ptr);
365    }
366
367    pub fn state(&self) -> FileState {
368        self.common.state
369    }
370}
371
/// Protocol state of an open netlink socket (the state every new socket starts in).
struct InitialState {
    /// Address recorded by `bind()`; `None` until a bind succeeds.
    bound_addr: Option<NetlinkAddr>,
    /// Handle returned when registering as a reader of the shared buffer; handed back to the
    /// buffer on close.
    reader_handle: ReaderHandle,
    /// Handle for the buffer listener registered in `ProtocolState::new`; never read.
    // NOTE(review): presumably held only so the listener stays registered while this state is
    // alive — confirm against `BufferHandle`'s drop behaviour.
    _buffer_handle: BufferHandle,
}
/// Protocol state entered when an `InitialState` socket is closed; it carries no data.
struct ClosedState {}
/// The socket's protocol state machine.
///
/// Each variant wraps its state in an `Option` so that `close()` can `take()` the state out of
/// `&mut self`, consume it, and store the successor state. The methods on this enum unwrap the
/// `Option` and therefore expect it to be `Some`.
enum ProtocolState {
    /// The socket is open.
    Initial(Option<InitialState>),
    /// The socket has been closed.
    Closed(Option<ClosedState>),
}
386
/// Generates `impl From<$type> for $parent`, wrapping the value as `$parent::$variant(Some(x))`.
macro_rules! state_upcast {
    // `$type` is the concrete state struct; `$parent::$variant` is the enum variant holding it.
    ($type:ty, $parent:ident::$variant:ident) => {
        impl From<$type> for $parent {
            fn from(x: $type) -> Self {
                Self::$variant(Some(x))
            }
        }
    };
}
397
// Allow each concrete state to be converted into `ProtocolState` with `.into()`.
state_upcast!(InitialState, ProtocolState::Initial);
state_upcast!(ClosedState, ProtocolState::Closed);
401
402impl ProtocolState {
403    fn new(common: &mut NetlinkSocketCommon, socket: &Weak<AtomicRefCell<NetlinkSocket>>) -> Self {
404        let mut cb_queue = CallbackQueue::new();
406
407        let reader_handle = common.buffer.borrow_mut().add_reader(&mut cb_queue);
409
410        let weak = Weak::clone(socket);
411        let buffer_handle = common.buffer.borrow_mut().add_listener(
412            BufferState::READABLE,
413            BufferSignals::BUFFER_GREW,
414            move |_, signals, cb_queue| {
415                if let Some(socket) = weak.upgrade() {
416                    let signals = if signals.contains(BufferSignals::BUFFER_GREW) {
417                        FileSignals::READ_BUFFER_GREW
418                    } else {
419                        FileSignals::empty()
420                    };
421
422                    socket.borrow_mut().refresh_file_state(signals, cb_queue);
423                }
424            },
425        );
426
427        ProtocolState::Initial(Some(InitialState {
428            bound_addr: None,
429            reader_handle,
430            _buffer_handle: buffer_handle,
431        }))
432    }
433
434    fn bound_address(&self) -> Result<Option<NetlinkAddr>, Errno> {
435        match self {
436            Self::Initial(x) => x.as_ref().unwrap().bound_address(),
437            Self::Closed(x) => x.as_ref().unwrap().bound_address(),
438        }
439    }
440
441    fn refresh_file_state(
442        &self,
443        common: &mut NetlinkSocketCommon,
444        signals: FileSignals,
445        cb_queue: &mut CallbackQueue,
446    ) {
447        match self {
448            Self::Initial(x) => x
449                .as_ref()
450                .unwrap()
451                .refresh_file_state(common, signals, cb_queue),
452            Self::Closed(x) => x
453                .as_ref()
454                .unwrap()
455                .refresh_file_state(common, signals, cb_queue),
456        }
457    }
458
459    fn close(
460        &mut self,
461        common: &mut NetlinkSocketCommon,
462        cb_queue: &mut CallbackQueue,
463    ) -> Result<(), SyscallError> {
464        let (new_state, rv) = match self {
465            Self::Initial(x) => x.take().unwrap().close(common, cb_queue),
466            Self::Closed(x) => x.take().unwrap().close(common, cb_queue),
467        };
468
469        *self = new_state;
470        rv
471    }
472
473    fn bind(
474        &mut self,
475        common: &mut NetlinkSocketCommon,
476        socket: &Arc<AtomicRefCell<NetlinkSocket>>,
477        addr: Option<&SockaddrStorage>,
478        rng: impl rand::Rng,
479    ) -> Result<(), SyscallError> {
480        match self {
481            Self::Initial(x) => x.as_mut().unwrap().bind(common, socket, addr, rng),
482            Self::Closed(x) => x.as_mut().unwrap().bind(common, socket, addr, rng),
483        }
484    }
485
486    fn sendmsg(
487        &mut self,
488        common: &mut NetlinkSocketCommon,
489        socket: &Arc<AtomicRefCell<NetlinkSocket>>,
490        args: SendmsgArgs,
491        mem: &mut MemoryManager,
492        cb_queue: &mut CallbackQueue,
493    ) -> Result<libc::ssize_t, SyscallError> {
494        match self {
495            Self::Initial(x) => x
496                .as_mut()
497                .unwrap()
498                .sendmsg(common, socket, args, mem, cb_queue),
499            Self::Closed(x) => x
500                .as_mut()
501                .unwrap()
502                .sendmsg(common, socket, args, mem, cb_queue),
503        }
504    }
505
506    fn recvmsg(
507        &mut self,
508        common: &mut NetlinkSocketCommon,
509        socket: &Arc<AtomicRefCell<NetlinkSocket>>,
510        args: RecvmsgArgs,
511        mem: &mut MemoryManager,
512        cb_queue: &mut CallbackQueue,
513    ) -> Result<RecvmsgReturn, SyscallError> {
514        match self {
515            Self::Initial(x) => x
516                .as_mut()
517                .unwrap()
518                .recvmsg(common, socket, args, mem, cb_queue),
519            Self::Closed(x) => x
520                .as_mut()
521                .unwrap()
522                .recvmsg(common, socket, args, mem, cb_queue),
523        }
524    }
525}
526
527impl InitialState {
528    fn bound_address(&self) -> Result<Option<NetlinkAddr>, Errno> {
529        Ok(self.bound_addr)
530    }
531
532    fn refresh_file_state(
533        &self,
534        common: &mut NetlinkSocketCommon,
535        signals: FileSignals,
536        cb_queue: &mut CallbackQueue,
537    ) {
538        let mut new_state = FileState::ACTIVE;
539
540        {
541            let buffer = common.buffer.borrow();
542
543            new_state.set(FileState::READABLE, buffer.has_data());
544            new_state.set(FileState::WRITABLE, common.sent_len < common.send_limit);
545        }
546
547        common.update_state(
548            FileState::all(),
549            new_state,
550            signals,
551            cb_queue,
552        );
553    }
554
555    fn close(
556        self,
557        common: &mut NetlinkSocketCommon,
558        cb_queue: &mut CallbackQueue,
559    ) -> (ProtocolState, Result<(), SyscallError>) {
560        common
562            .buffer
563            .borrow_mut()
564            .remove_reader(self.reader_handle, cb_queue);
565
566        let new_state = ClosedState {};
567        new_state.refresh_file_state(common, FileSignals::empty(), cb_queue);
568        (new_state.into(), Ok(()))
569    }
570
571    fn bind(
572        &mut self,
573        _common: &mut NetlinkSocketCommon,
574        _socket: &Arc<AtomicRefCell<NetlinkSocket>>,
575        addr: Option<&SockaddrStorage>,
576        _rng: impl rand::Rng,
577    ) -> Result<(), SyscallError> {
578        if self.bound_addr.is_some() {
580            return Err(Errno::EINVAL.into());
581        }
582        if addr.is_none() {
584            return Err(Errno::EFAULT.into());
585        }
586
587        let Some(addr) = addr.and_then(|x| x.as_netlink()) else {
589            log::warn!("Attempted to bind netlink socket to non-netlink address {addr:?}");
590            return Err(Errno::EINVAL.into());
591        };
592
593        self.bound_addr = Some(*addr);
597
598        if (addr.groups() & !(RTMGRP_IPV4_IFADDR | RTMGRP_IPV6_IFADDR)) != 0 {
601            log::warn!(
602                "Attempted to bind netlink socket to an address with unsupported groups {}",
603                addr.groups()
604            );
605            return Err(Errno::EINVAL.into());
606        }
607
608        Ok(())
609    }
610
611    fn sendmsg(
612        &mut self,
613        common: &mut NetlinkSocketCommon,
614        socket: &Arc<AtomicRefCell<NetlinkSocket>>,
615        args: SendmsgArgs,
616        mem: &mut MemoryManager,
617        cb_queue: &mut CallbackQueue,
618    ) -> Result<libc::ssize_t, SyscallError> {
619        if !args.control_ptr.ptr().is_null() {
620            log::debug!("Netlink sockets don't yet support control data for sendmsg()");
621            return Err(Errno::EINVAL.into());
622        }
623
624        if let Some(addr) = args.addr {
626            let Some(addr) = addr.as_netlink() else {
628                log::warn!("Attempted to send to non-netlink address {:?}", args.addr);
629                return Err(Errno::EINVAL.into());
630            };
631            if addr.pid() != 0 {
633                log::warn!("Attempted to send to non-kernel netlink address {addr:?}");
634                return Err(Errno::EINVAL.into());
635            }
636            if addr.groups() != 0 {
638                log::warn!("Attempted to send to netlink groups {addr:?}");
639                return Err(Errno::EINVAL.into());
640            }
641        }
642
643        let rv = common.sendmsg(socket, args.iovs, args.flags, mem, cb_queue)?;
644
645        self.refresh_file_state(common, FileSignals::empty(), cb_queue);
646
647        Ok(rv.try_into().unwrap())
648    }
649
650    fn recvmsg(
651        &mut self,
652        common: &mut NetlinkSocketCommon,
653        socket: &Arc<AtomicRefCell<NetlinkSocket>>,
654        args: RecvmsgArgs,
655        mem: &mut MemoryManager,
656        cb_queue: &mut CallbackQueue,
657    ) -> Result<RecvmsgReturn, SyscallError> {
658        if !args.control_ptr.ptr().is_null() {
659            log::debug!("Netlink sockets don't yet support control data for recvmsg()");
660            return Err(Errno::EINVAL.into());
661        }
662        let Some(flags) = MsgFlags::from_bits(args.flags) else {
663            warn_once_then_debug!("Unrecognized recv flags: {:#b}", args.flags);
664            return Err(Errno::EINVAL.into());
665        };
666
667        let mut packet_buffer = Vec::new();
668        let (_rv, _num_removed_from_buf) =
669            common.recvmsg(socket, &mut packet_buffer, flags, mem, cb_queue)?;
670        self.refresh_file_state(common, FileSignals::empty(), cb_queue);
671
672        let mut writer = IoVecWriter::new(args.iovs, mem);
673
674        let src_addr = SockaddrStorage::from_netlink(&NetlinkAddr::new(0, 0));
676
677        if packet_buffer.len() < std::mem::size_of::<nlmsghdr>() {
678            log::warn!("The processed packet is too short");
679            return Err(Errno::EINVAL.into());
680        }
681
682        let buffer = {
683            let nlmsg_type = &packet_buffer[memoffset::span_of!(nlmsghdr, nlmsg_type)];
684            let nlmsg_type = u16::from_ne_bytes(nlmsg_type.try_into().unwrap());
685
686            match nlmsg_type {
687                RTM_GETLINK => {
688                    let nlmsghdr_len = std::mem::size_of::<nlmsghdr>();
689                    let ifinfomsg_len = std::mem::size_of::<ifinfomsg>();
690                    let header_len = nlmsghdr_len + ifinfomsg_len;
691
692                    if (nlmsghdr_len..header_len).contains(&packet_buffer.len()) {
698                        log::debug!(
699                            "Padding the RTM_GETLINK with zeroes to meet the minimum length"
700                        );
701                        packet_buffer.resize(header_len, 0);
702                        packet_buffer[memoffset::span_of!(nlmsghdr, nlmsg_len)]
703                            .copy_from_slice(&(header_len as u32).to_ne_bytes()[..]);
704                    }
705                    self.handle_ifinfomsg(common, &packet_buffer[..])
706                }
707                RTM_GETADDR => {
708                    let nlmsghdr_len = std::mem::size_of::<nlmsghdr>();
709                    let ifaddrmsg_len = std::mem::size_of::<ifaddrmsg>();
710                    let header_len = nlmsghdr_len + ifaddrmsg_len;
711
712                    if (nlmsghdr_len..header_len).contains(&packet_buffer.len()) {
718                        log::debug!(
719                            "Padding the RTM_GETADDR with zeroes to meet the minimum length"
720                        );
721                        packet_buffer.resize(header_len, 0);
722                        packet_buffer[memoffset::span_of!(nlmsghdr, nlmsg_len)]
723                            .copy_from_slice(&(header_len as u32).to_ne_bytes()[..]);
724                    }
725                    self.handle_ifaddrmsg(common, &packet_buffer[..])
726                }
727                _ => {
728                    warn_once_then_debug!(
729                        "Found unsupported nlmsg_type: {nlmsg_type} (only RTM_GETLINK
730                        and RTM_GETADDR are supported)"
731                    );
732                    self.handle_error(&packet_buffer[..])
733                }
734            }
735        };
736
737        let mut total_copied = 0;
739        let mut buf = buffer.as_slice();
740        while !buf.is_empty() {
741            match writer.write(buf) {
742                Ok(0) => break,
743                Ok(n) => {
744                    buf = &buf[n..];
745                    total_copied += n;
746                }
747                Err(ref e) if e.kind() == ErrorKind::Interrupted => continue,
748                Err(e) => return Err(e.into()),
749            }
750        }
751
752        let return_val = if flags.contains(MsgFlags::MSG_TRUNC) {
753            buffer.len()
754        } else {
755            total_copied
756        };
757
758        Ok(RecvmsgReturn {
759            return_val: return_val.try_into().unwrap(),
760            addr: Some(src_addr),
761            msg_flags: 0,
762            control_len: 0,
763        })
764    }
765
766    fn handle_error(&self, bytes: &[u8]) -> Vec<u8> {
767        let nlmsg_seq = match bytes.get(memoffset::span_of!(nlmsghdr, nlmsg_seq)) {
769            Some(x) => u32::from_ne_bytes(x.try_into().unwrap()),
770            None => 0,
771        };
772
773        let msg = NlmsghdrBuilder::default()
775            .nl_type(Nlmsg::Error)
776            .nl_flags(NlmF::empty())
777            .nl_seq(nlmsg_seq)
778            .nl_payload(NlPayload::<Nlmsg, ()>::Empty)
779            .build()
780            .expect("NlmsghdrBuilder missing a required field");
781
782        let mut buffer = Cursor::new(Vec::new());
783        msg.to_bytes(&mut buffer).unwrap();
784        buffer.into_inner()
785    }
786
787    fn handle_ifaddrmsg(&self, common: &mut NetlinkSocketCommon, bytes: &[u8]) -> Vec<u8> {
788        let Ok(nlmsg) = Nlmsghdr::<Rtm, Ifaddrmsg>::from_bytes(&mut Cursor::new(bytes)) else {
789            log::warn!("Failed to deserialize the message");
790            return self.handle_error(bytes);
791        };
792
793        let Some(ifaddrmsg) = nlmsg.get_payload() else {
794            log::warn!("Failed to find the payload");
795            return self.handle_error(bytes);
796        };
797
798        if *ifaddrmsg.ifa_family() != RtAddrFamily::Unspecified
800            && *ifaddrmsg.ifa_family() != RtAddrFamily::Inet
801        {
802            log::warn!("Unsupported ifa_family (only AF_UNSPEC and AF_INET are supported)");
803            return self.handle_error(bytes);
804        }
805
806        if *ifaddrmsg.ifa_prefixlen() != 0
808            || !ifaddrmsg.ifa_flags().is_empty()
809            || *ifaddrmsg.ifa_index() != 0
810            || *ifaddrmsg.ifa_scope() != RtScope::Universe
811        {
812            log::warn!(
813                "Unsupported ifa_prefixlen, ifa_flags, ifa_scope, or ifa_index (they have to be 0)",
814            );
815            return self.handle_error(bytes);
816        }
817
818        let mut buffer = Cursor::new(Vec::new());
819        for interface in &common.interfaces {
821            let address = interface.address.octets();
822            let broadcast = Ipv4Addr::from(
823                0xffff_ffff_u32
824                    .checked_shr(u32::from(interface.prefix_len))
825                    .unwrap_or(0)
826                    | u32::from(interface.address),
827            )
828            .octets();
829            let mut label = Vec::from(interface.label.as_bytes());
830            label.push(0); let attrs = [
834                RtattrBuilder::default()
838                    .rta_type(Ifa::Address)
839                    .rta_payload(Buffer::from(&address[..]))
840                    .build()
841                    .unwrap(),
842                RtattrBuilder::default()
843                    .rta_type(Ifa::Local)
844                    .rta_payload(Buffer::from(&address[..]))
845                    .build()
846                    .unwrap(),
847                RtattrBuilder::default()
848                    .rta_type(Ifa::Broadcast)
849                    .rta_payload(Buffer::from(&broadcast[..]))
850                    .build()
851                    .unwrap(),
852                RtattrBuilder::default()
853                    .rta_type(Ifa::Label)
854                    .rta_payload(Buffer::from(label))
855                    .build()
856                    .unwrap(),
857            ];
858            let ifaddrmsg = IfaddrmsgBuilder::default()
859                .ifa_family(RtAddrFamily::Inet)
860                .ifa_prefixlen(interface.prefix_len)
861                .ifa_flags(IfaF::PERMANENT)
863                .ifa_scope(interface.scope)
864                .ifa_index(interface.index)
865                .rtattrs(RtBuffer::from_iter(attrs))
866                .build()
867                .expect("IfaddrmsgBuilder missing a required field");
868            let nlmsg = NlmsghdrBuilder::default()
869                .nl_type(Rtm::Newaddr)
870                .nl_flags(NlmF::MULTI)
872                .nl_seq(*nlmsg.nl_seq())
874                .nl_payload(NlPayload::Payload(ifaddrmsg))
875                .build()
876                .expect("NlmsghdrBuilder missing a required field");
877            nlmsg.to_bytes(&mut buffer).unwrap();
878        }
879        let done_msg = NlmsghdrBuilder::default()
881            .nl_type(Nlmsg::Done)
882            .nl_flags(NlmF::MULTI)
883            .nl_seq(*nlmsg.nl_seq())
885            .nl_payload(NlPayload::Payload(0u32))
891            .build()
892            .expect("NlmsghdrBuilder missing a required field");
893        done_msg.to_bytes(&mut buffer).unwrap();
894
895        buffer.into_inner()
896    }
897
898    fn handle_ifinfomsg(&self, common: &mut NetlinkSocketCommon, bytes: &[u8]) -> Vec<u8> {
899        let Ok(nlmsg) = Nlmsghdr::<Rtm, Ifinfomsg>::from_bytes(&mut Cursor::new(bytes)) else {
900            log::warn!("Failed to deserialize the message");
901            return self.handle_error(bytes);
902        };
903
904        let Some(ifinfomsg) = nlmsg.get_payload() else {
905            log::warn!("Failed to find the payload");
906            return self.handle_error(bytes);
907        };
908
909        if *ifinfomsg.ifi_family() != RtAddrFamily::Unspecified
911            && *ifinfomsg.ifi_family() != RtAddrFamily::Inet
912        {
913            warn_once_then_debug!(
914                "Unsupported ifi_family (only AF_UNSPEC and AF_INET are supported)"
915            );
916            return self.handle_error(bytes);
917        }
918
919        if *ifinfomsg.ifi_type() != 0.into()
921            || *ifinfomsg.ifi_index() != 0
922            || !ifinfomsg.ifi_flags().is_empty()
923        {
924            warn_once_then_debug!(
925                "Unsupported ifi_type, ifi_index, or ifi_flags (they have to be 0)"
926            );
927            return self.handle_error(bytes);
928        }
929
930        let mut buffer = Cursor::new(Vec::new());
934        for interface in &common.interfaces {
936            let mut label = Vec::from(interface.label.as_bytes());
937            label.push(0); let attrs = [
941                RtattrBuilder::default()
942                    .rta_type(Ifla::Ifname)
943                    .rta_payload(Buffer::from(label))
944                    .build()
945                    .unwrap(),
946                RtattrBuilder::default()
950                    .rta_type(Ifla::Txqlen)
951                    .rta_payload(Buffer::from(&u32::to_le_bytes(1000)[..]))
952                    .build()
953                    .unwrap(),
954                RtattrBuilder::default()
955                    .rta_type(Ifla::Mtu)
956                    .rta_payload(Buffer::from(&u32::to_le_bytes(interface.mtu)[..]))
957                    .build()
958                    .unwrap(),
959                ];
961            let flags = if interface.if_type == Arphrd::Loopback {
962                Iff::UP | Iff::LOOPBACK | Iff::RUNNING
963            } else {
964                Iff::UP | Iff::BROADCAST | Iff::RUNNING | Iff::MULTICAST
966            };
967            let interface_index = interface
968                .index
969                .try_into()
970                .expect("interface index too large");
971
972            let ifinfomsg = IfinfomsgBuilder::default()
973                .ifi_family(RtAddrFamily::Inet)
974                .ifi_type(interface.if_type)
975                .ifi_index(interface_index)
976                .ifi_flags(flags)
977                .ifi_change(Iff::from_bits_retain(0xffffffff))
979                .rtattrs(RtBuffer::from_iter(attrs))
980                .build()
981                .expect("IfinfomsgBuilder missing a required field");
982            let nlmsg = NlmsghdrBuilder::default()
983                .nl_type(Rtm::Newlink)
984                .nl_flags(NlmF::MULTI)
986                .nl_seq(*nlmsg.nl_seq())
988                .nl_payload(NlPayload::Payload(ifinfomsg))
989                .build()
990                .expect("NlmsghdrBuilder missing a required field");
991            nlmsg.to_bytes(&mut buffer).unwrap();
992        }
993        let done_msg = NlmsghdrBuilder::default()
995            .nl_type(Nlmsg::Done)
996            .nl_flags(NlmF::MULTI)
997            .nl_seq(*nlmsg.nl_seq())
999            .nl_payload(NlPayload::Payload(0u32))
1005            .build()
1006            .expect("NlmsghdrBuilder missing a required field");
1007        done_msg.to_bytes(&mut buffer).unwrap();
1008
1009        buffer.into_inner()
1010    }
1011}
1012
1013impl ClosedState {
1014    fn bound_address(&self) -> Result<Option<NetlinkAddr>, Errno> {
1015        Ok(None)
1016    }
1017
1018    fn refresh_file_state(
1019        &self,
1020        common: &mut NetlinkSocketCommon,
1021        signals: FileSignals,
1022        cb_queue: &mut CallbackQueue,
1023    ) {
1024        common.update_state(
1025            FileState::all(),
1026            FileState::CLOSED,
1027            signals,
1028            cb_queue,
1029        );
1030    }
1031
1032    fn close(
1033        self,
1034        _common: &mut NetlinkSocketCommon,
1035        _cb_queue: &mut CallbackQueue,
1036    ) -> (ProtocolState, Result<(), SyscallError>) {
1037        panic!("Trying to close an already closed socket");
1039    }
1040
1041    fn bind(
1042        &mut self,
1043        _common: &mut NetlinkSocketCommon,
1044        _socket: &Arc<AtomicRefCell<NetlinkSocket>>,
1045        _addr: Option<&SockaddrStorage>,
1046        _rng: impl rand::Rng,
1047    ) -> Result<(), SyscallError> {
1048        log::warn!("bind() while in state {}", std::any::type_name::<Self>());
1050        Err(Errno::EOPNOTSUPP.into())
1051    }
1052
1053    fn sendmsg(
1054        &mut self,
1055        _common: &mut NetlinkSocketCommon,
1056        _socket: &Arc<AtomicRefCell<NetlinkSocket>>,
1057        _args: SendmsgArgs,
1058        _mem: &mut MemoryManager,
1059        _cb_queue: &mut CallbackQueue,
1060    ) -> Result<libc::ssize_t, SyscallError> {
1061        log::warn!("sendmsg() while in state {}", std::any::type_name::<Self>());
1063        Err(Errno::EOPNOTSUPP.into())
1064    }
1065
1066    fn recvmsg(
1067        &mut self,
1068        _common: &mut NetlinkSocketCommon,
1069        _socket: &Arc<AtomicRefCell<NetlinkSocket>>,
1070        _args: RecvmsgArgs,
1071        _mem: &mut MemoryManager,
1072        _cb_queue: &mut CallbackQueue,
1073    ) -> Result<RecvmsgReturn, SyscallError> {
1074        log::warn!("recvmsg() while in state {}", std::any::type_name::<Self>());
1076        Err(Errno::EOPNOTSUPP.into())
1077    }
1078}
1079
/// Description of one network interface known to a netlink socket (see
/// `NetlinkSocketCommon::interfaces`).
struct Interface {
    /// IPv4 address of the interface.
    /// NOTE(review): presumably what address dumps report — not used in the visible code; confirm.
    address: Ipv4Addr,
    /// Interface name; `handle_ifinfomsg` sends it NUL-terminated as the `Ifla::Ifname` attribute.
    label: String,
    /// NOTE(review): presumably the prefix length belonging to `address` — confirm against the
    /// address-dump handler.
    prefix_len: u8,
    /// Hardware type; copied into `ifi_type` of link replies and compared against
    /// `Arphrd::Loopback` to choose the reported interface flags.
    if_type: Arphrd,
    /// MTU; sent as the `Ifla::Mtu` attribute of link replies.
    mtu: u32,
    /// NOTE(review): presumably the scope reported for `address` — not used in the visible code;
    /// confirm.
    scope: RtScope,
    /// Interface index; converted into the `ifi_index` field of link replies.
    index: libc::c_uint,
}
1090
/// State shared by all protocol states of a netlink socket.
struct NetlinkSocketCommon {
    /// Packet buffer of the socket: `sendmsg` writes packets into it and `recvmsg` peeks or
    /// reads from it.
    buffer: Arc<AtomicRefCell<SharedBuf>>,
    /// Upper bound, in bytes, on what `sendmsg` accepts on top of `sent_len`.
    send_limit: u64,
    /// Number of bytes accounted against `send_limit`; increased by every successful `sendmsg`.
    /// NOTE(review): presumably decreased elsewhere when the data is consumed — confirm.
    sent_len: u64,
    /// Listeners that are notified by `handle_state_change`.
    event_source: StateEventSource,
    /// Current file state; modified through `update_state`.
    state: FileState,
    /// File status flags; `NONBLOCK` makes `sendmsg`/`recvmsg` behave as if `MSG_DONTWAIT` was
    /// passed.
    status: FileStatus,
    /// NOTE(review): presumably tracks whether an open file currently refers to this socket —
    /// not used in the visible code; confirm.
    has_open_file: bool,
    /// Interfaces reported in replies to dump requests.
    interfaces: Vec<Interface>,
}
1107
1108impl NetlinkSocketCommon {
1109    pub fn supports_sa_restart(&self) -> bool {
1110        true
1111    }
1112
1113    pub fn sendmsg(
1114        &mut self,
1115        socket: &Arc<AtomicRefCell<NetlinkSocket>>,
1116        iovs: &[IoVec],
1117        flags: libc::c_int,
1118        mem: &mut MemoryManager,
1119        cb_queue: &mut CallbackQueue,
1120    ) -> Result<usize, SyscallError> {
1121        let supported_flags = MsgFlags::MSG_DONTWAIT | MsgFlags::MSG_NOSIGNAL | MsgFlags::MSG_TRUNC;
1124
1125        let Some(mut flags) = MsgFlags::from_bits(flags) else {
1128            warn_once_then_debug!("Unrecognized send flags: {:#b}", flags);
1129            return Err(Errno::EINVAL.into());
1130        };
1131        if flags.intersects(!supported_flags) {
1132            warn_once_then_debug!("Unsupported send flags: {:?}", flags);
1133            return Err(Errno::EINVAL.into());
1134        }
1135
1136        if self.status.contains(FileStatus::NONBLOCK) {
1137            flags.insert(MsgFlags::MSG_DONTWAIT);
1138        }
1139
1140        let result = (|| {
1142            let len = iovs.iter().map(|x| x.len).sum::<libc::size_t>();
1143
1144            let space_available = self
1147                .send_limit
1148                .saturating_sub(self.sent_len)
1149                .try_into()
1150                .unwrap();
1151
1152            if space_available == 0 {
1153                return Err(Errno::EAGAIN);
1154            }
1155
1156            if len > space_available {
1157                if len <= self.send_limit.try_into().unwrap() {
1158                    return Err(Errno::EAGAIN);
1160                } else {
1161                    return Err(Errno::EMSGSIZE);
1163                }
1164            }
1165
1166            let reader = IoVecReader::new(iovs, mem);
1167            let reader = reader.take(len.try_into().unwrap());
1168
1169            self.buffer
1172                .borrow_mut()
1173                .write_packet(reader, len, cb_queue)
1174                .map_err(|e| Errno::try_from(e).unwrap())?;
1175
1176            self.sent_len += u64::try_from(len).unwrap();
1178            Ok(len)
1179        })();
1180
1181        if result.as_ref().err() == Some(&Errno::EWOULDBLOCK)
1183            && !flags.contains(MsgFlags::MSG_DONTWAIT)
1184        {
1185            return Err(SyscallError::new_blocked_on_file(
1186                File::Socket(Socket::Netlink(socket.clone())),
1187                FileState::WRITABLE,
1188                self.supports_sa_restart(),
1189            ));
1190        }
1191
1192        Ok(result?)
1193    }
1194
1195    pub fn recvmsg<W: Write>(
1196        &mut self,
1197        socket: &Arc<AtomicRefCell<NetlinkSocket>>,
1198        dst: W,
1199        mut flags: MsgFlags,
1200        _mem: &mut MemoryManager,
1201        cb_queue: &mut CallbackQueue,
1202    ) -> Result<(usize, usize), SyscallError> {
1203        let supported_flags = MsgFlags::MSG_DONTWAIT | MsgFlags::MSG_PEEK | MsgFlags::MSG_TRUNC;
1204
1205        if flags.intersects(!supported_flags) {
1208            warn_once_then_debug!("Unsupported recv flags: {:?}", flags);
1209            return Err(Errno::EINVAL.into());
1210        }
1211
1212        if self.status.contains(FileStatus::NONBLOCK) {
1213            flags.insert(MsgFlags::MSG_DONTWAIT);
1214        }
1215
1216        let result = (|| {
1218            let mut buffer = self.buffer.borrow_mut();
1219
1220            if !buffer.has_data() {
1222                return Err(Errno::EWOULDBLOCK);
1223            }
1224
1225            let (num_copied, num_removed_from_buf) = if flags.contains(MsgFlags::MSG_PEEK) {
1226                buffer.peek(dst).map_err(|e| Errno::try_from(e).unwrap())?
1227            } else {
1228                buffer
1229                    .read(dst, cb_queue)
1230                    .map_err(|e| Errno::try_from(e).unwrap())?
1231            };
1232
1233            if flags.contains(MsgFlags::MSG_TRUNC) {
1234                Ok((num_removed_from_buf, num_removed_from_buf))
1236            } else {
1237                Ok((num_copied, num_removed_from_buf))
1238            }
1239        })();
1240
1241        if result.as_ref().err() == Some(&Errno::EWOULDBLOCK)
1243            && !flags.contains(MsgFlags::MSG_DONTWAIT)
1244        {
1245            return Err(SyscallError::new_blocked_on_file(
1246                File::Socket(Socket::Netlink(socket.clone())),
1247                FileState::READABLE,
1248                self.supports_sa_restart(),
1249            ));
1250        }
1251
1252        Ok(result?)
1253    }
1254
1255    fn update_state(
1256        &mut self,
1257        mask: FileState,
1258        state: FileState,
1259        signals: FileSignals,
1260        cb_queue: &mut CallbackQueue,
1261    ) {
1262        let old_state = self.state;
1263
1264        self.state.remove(mask);
1266        self.state.insert(state & mask);
1267
1268        self.handle_state_change(old_state, signals, cb_queue);
1269    }
1270
1271    fn handle_state_change(
1272        &mut self,
1273        old_state: FileState,
1274        signals: FileSignals,
1275        cb_queue: &mut CallbackQueue,
1276    ) {
1277        let states_changed = self.state ^ old_state;
1278
1279        if states_changed.is_empty() && signals.is_empty() {
1281            return;
1282        }
1283
1284        self.event_source
1285            .notify_listeners(self.state, states_changed, signals, cb_queue);
1286    }
1287}
1288
/// Socket types that a netlink socket can be created with.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub enum NetlinkSocketType {
    /// Corresponds to `libc::SOCK_DGRAM`.
    Dgram,
    /// Corresponds to `libc::SOCK_RAW`.
    Raw,
}
1294
1295impl TryFrom<libc::c_int> for NetlinkSocketType {
1296    type Error = NetlinkSocketTypeConversionError;
1297    fn try_from(val: libc::c_int) -> Result<Self, Self::Error> {
1298        match val {
1299            libc::SOCK_DGRAM => Ok(Self::Dgram),
1300            libc::SOCK_RAW => Ok(Self::Raw),
1301            x => Err(NetlinkSocketTypeConversionError(x)),
1302        }
1303    }
1304}
1305
/// Error returned when an integer cannot be converted into a [`NetlinkSocketType`].
///
/// Holds the rejected value so that it can be included in the error message.
#[derive(Copy, Clone, Debug)]
pub struct NetlinkSocketTypeConversionError(libc::c_int);

// No methods are overridden; the message comes from the `Display` implementation.
impl std::error::Error for NetlinkSocketTypeConversionError {}
1310
1311impl std::fmt::Display for NetlinkSocketTypeConversionError {
1312    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
1313        write!(
1314            f,
1315            "Invalid socket type {}; netlink sockets only support SOCK_DGRAM and SOCK_RAW",
1316            self.0
1317        )
1318    }
1319}
1320
/// Netlink protocol families that can be requested when creating a netlink socket.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub enum NetlinkFamily {
    /// Corresponds to `libc::NETLINK_ROUTE`.
    Route,
}
1325
1326impl TryFrom<libc::c_int> for NetlinkFamily {
1327    type Error = NetlinkFamilyConversionError;
1328    fn try_from(val: libc::c_int) -> Result<Self, Self::Error> {
1329        match val {
1330            libc::NETLINK_ROUTE => Ok(Self::Route),
1331            x => Err(NetlinkFamilyConversionError(x)),
1332        }
1333    }
1334}
1335
/// Error returned when an integer cannot be converted into a [`NetlinkFamily`].
///
/// Holds the rejected value so that it can be included in the error message.
#[derive(Copy, Clone, Debug)]
pub struct NetlinkFamilyConversionError(libc::c_int);

// No methods are overridden; the message comes from the `Display` implementation.
impl std::error::Error for NetlinkFamilyConversionError {}
1340
1341impl std::fmt::Display for NetlinkFamilyConversionError {
1342    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
1343        write!(
1344            f,
1345            "Invalid netlink family {}; netlink families only support NETLINK_ROUTE",
1346            self.0
1347        )
1348    }
1349}