1use std::io::{Cursor, ErrorKind, Read, Write};
2use std::net::Ipv4Addr;
3use std::sync::{Arc, Weak};
4
5use atomic_refcell::AtomicRefCell;
6use linux_api::errno::Errno;
7use linux_api::ioctls::IoctlRequest;
8use linux_api::netlink::{ifaddrmsg, ifinfomsg, nlmsghdr};
9use linux_api::rtnetlink::{RTM_GETADDR, RTM_GETLINK, RTMGRP_IPV4_IFADDR, RTMGRP_IPV6_IFADDR};
10use linux_api::socket::Shutdown;
11use neli::consts::nl::{NlmF, Nlmsg};
12use neli::consts::rtnl::{Arphrd, Ifa, IfaF, Iff, Ifla, RtAddrFamily, RtScope, Rtm};
13use neli::nl::{NlPayload, Nlmsghdr, NlmsghdrBuilder};
14use neli::rtnl::{Ifaddrmsg, IfaddrmsgBuilder, Ifinfomsg, IfinfomsgBuilder, RtattrBuilder};
15use neli::types::{Buffer, RtBuffer};
16use neli::{FromBytes, ToBytes};
17use nix::sys::socket::{MsgFlags, NetlinkAddr};
18use shadow_shim_helper_rs::syscall_types::ForeignPtr;
19
20use crate::core::worker::Worker;
21use crate::cshadow as c;
22use crate::host::descriptor::listener::{StateEventSource, StateListenHandle, StateListenerFilter};
23use crate::host::descriptor::shared_buf::{
24 BufferHandle, BufferSignals, BufferState, ReaderHandle, SharedBuf,
25};
26use crate::host::descriptor::socket::{RecvmsgArgs, RecvmsgReturn, SendmsgArgs, Socket};
27use crate::host::descriptor::{
28 File, FileMode, FileSignals, FileState, FileStatus, OpenFile, SyscallResult,
29};
30use crate::host::memory_manager::MemoryManager;
31use crate::host::network::namespace::NetworkNamespace;
32use crate::host::syscall::io::{IoVec, IoVecReader, IoVecWriter};
33use crate::host::syscall::types::SyscallError;
34use crate::utility::HostTreePointer;
35use crate::utility::callback_queue::CallbackQueue;
36use crate::utility::sockaddr::SockaddrStorage;
37
/// Initial send limit (in bytes) of a new netlink socket, before any `SO_SNDBUF` adjustment.
/// 208 KiB = 212_992 bytes.
const NETLINK_SOCKET_DEFAULT_BUFFER_SIZE: u64 = 208 * 1024;
40
/// A simulated `AF_NETLINK` socket.
///
/// Request messages written with `sendmsg()` are queued in a shared buffer; the reply to a
/// request is generated when the application reads it back with `recvmsg()`.
pub struct NetlinkSocket {
    /// State shared by all protocol states (buffer, send accounting, file state, interfaces).
    common: NetlinkSocketCommon,
    /// Current protocol state: `Initial` while open, `Closed` after `close()`.
    protocol_state: ProtocolState,
}
47
48impl NetlinkSocket {
49 pub fn new(
50 status: FileStatus,
51 _socket_type: NetlinkSocketType,
52 _family: NetlinkFamily,
53 ) -> Arc<AtomicRefCell<Self>> {
54 Arc::new_cyclic(|weak| {
55 let buffer = SharedBuf::new(usize::MAX);
57 let buffer = Arc::new(AtomicRefCell::new(buffer));
58
59 let default_ip = Worker::with_active_host(|host| host.default_ip()).unwrap();
61 let interfaces = vec![
63 Interface {
64 address: Ipv4Addr::LOCALHOST,
65 label: String::from("lo"),
66 prefix_len: 8,
67 if_type: Arphrd::Loopback,
68 mtu: c::CONFIG_MTU,
69 scope: RtScope::Host,
70 index: 1,
71 },
72 Interface {
73 address: default_ip,
74 label: String::from("eth0"),
75 prefix_len: 24,
76 if_type: Arphrd::Ether,
77 mtu: c::CONFIG_MTU,
78 scope: RtScope::Universe,
79 index: 2,
80 },
81 ];
82
83 let mut common = NetlinkSocketCommon {
84 buffer,
85 send_limit: NETLINK_SOCKET_DEFAULT_BUFFER_SIZE,
86 sent_len: 0,
87 event_source: StateEventSource::new(),
88 state: FileState::ACTIVE,
89 status,
90 has_open_file: false,
91 interfaces,
92 };
93 let protocol_state = ProtocolState::new(&mut common, weak);
94 let mut socket = Self {
95 common,
96 protocol_state,
97 };
98 CallbackQueue::queue_and_run_with_legacy(|cb_queue| {
99 socket.refresh_file_state(FileSignals::empty(), cb_queue)
100 });
101
102 AtomicRefCell::new(socket)
103 })
104 }
105
106 pub fn status(&self) -> FileStatus {
107 self.common.status
108 }
109
110 pub fn set_status(&mut self, status: FileStatus) {
111 self.common.status = status;
112 }
113
114 pub fn mode(&self) -> FileMode {
115 FileMode::READ | FileMode::WRITE
116 }
117
118 pub fn has_open_file(&self) -> bool {
119 self.common.has_open_file
120 }
121
122 pub fn supports_sa_restart(&self) -> bool {
123 self.common.supports_sa_restart()
124 }
125
126 pub fn set_has_open_file(&mut self, val: bool) {
127 self.common.has_open_file = val;
128 }
129
130 pub fn getsockname(&self) -> Result<Option<nix::sys::socket::NetlinkAddr>, Errno> {
131 self.protocol_state.bound_address()
132 }
133
134 pub fn getpeername(&self) -> Result<Option<nix::sys::socket::NetlinkAddr>, Errno> {
135 warn_once_then_debug!(
136 "getpeername() syscall not yet supported for netlink sockets; Returning ENOSYS"
137 );
138 Err(Errno::ENOSYS)
139 }
140
141 pub fn address_family(&self) -> linux_api::socket::AddressFamily {
142 linux_api::socket::AddressFamily::AF_NETLINK
143 }
144
145 pub fn close(&mut self, cb_queue: &mut CallbackQueue) -> Result<(), SyscallError> {
146 self.protocol_state.close(&mut self.common, cb_queue)
147 }
148
149 fn refresh_file_state(&mut self, signals: FileSignals, cb_queue: &mut CallbackQueue) {
150 self.protocol_state
151 .refresh_file_state(&mut self.common, signals, cb_queue)
152 }
153
154 pub fn shutdown(
155 &mut self,
156 _how: Shutdown,
157 _cb_queue: &mut CallbackQueue,
158 ) -> Result<(), SyscallError> {
159 warn_once_then_debug!(
160 "shutdown() syscall not yet supported for netlink sockets; Returning ENOSYS"
161 );
162 Err(Errno::ENOSYS.into())
163 }
164
165 pub fn getsockopt(
166 &mut self,
167 _level: libc::c_int,
168 _optname: libc::c_int,
169 _optval_ptr: ForeignPtr<()>,
170 _optlen: libc::socklen_t,
171 _memory_manager: &mut MemoryManager,
172 _cb_queue: &mut CallbackQueue,
173 ) -> Result<libc::socklen_t, SyscallError> {
174 warn_once_then_debug!(
175 "getsockopt() syscall not yet supported for netlink sockets; Returning ENOSYS"
176 );
177 Err(Errno::ENOSYS.into())
178 }
179
180 pub fn setsockopt(
181 &mut self,
182 level: libc::c_int,
183 optname: libc::c_int,
184 optval_ptr: ForeignPtr<()>,
185 optlen: libc::socklen_t,
186 memory_manager: &MemoryManager,
187 ) -> Result<(), SyscallError> {
188 match (level, optname) {
189 (libc::SOL_SOCKET, libc::SO_SNDBUF) => {
190 type OptType = libc::c_int;
191
192 if usize::try_from(optlen).unwrap() < std::mem::size_of::<OptType>() {
193 return Err(Errno::EINVAL.into());
194 }
195
196 let optval_ptr = optval_ptr.cast::<OptType>();
197 let val: u64 = memory_manager
198 .read(optval_ptr)?
199 .try_into()
200 .or(Err(Errno::EINVAL))?;
201
202 let val = val * 2;
204 let val = std::cmp::max(val, self.common.sent_len);
206 let val = std::cmp::max(val, 4096);
208 let val = std::cmp::min(val, 268435456); self.common.send_limit = val;
211 }
212 (libc::SOL_SOCKET, libc::SO_RCVBUF) => {
213 }
217 _ => {
218 warn_once_then_debug!(
219 "setsockopt called with unsupported level {level} and opt {optname}"
220 );
221 return Err(Errno::ENOPROTOOPT.into());
222 }
223 }
224
225 Ok(())
226 }
227
228 pub fn bind(
229 socket: &Arc<AtomicRefCell<Self>>,
230 addr: Option<&SockaddrStorage>,
231 _net_ns: &NetworkNamespace,
232 rng: impl rand::Rng,
233 ) -> Result<(), SyscallError> {
234 let socket_ref = &mut *socket.borrow_mut();
235 socket_ref
236 .protocol_state
237 .bind(&mut socket_ref.common, socket, addr, rng)
238 }
239
240 pub fn readv(
241 &mut self,
242 _iovs: &[IoVec],
243 _offset: Option<libc::off_t>,
244 _flags: libc::c_int,
245 _mem: &mut MemoryManager,
246 _cb_queue: &mut CallbackQueue,
247 ) -> Result<libc::ssize_t, SyscallError> {
248 panic!("Called NetlinkSocket::readv() on a netlink socket.");
252 }
253
254 pub fn writev(
255 &mut self,
256 _iovs: &[IoVec],
257 _offset: Option<libc::off_t>,
258 _flags: libc::c_int,
259 _mem: &mut MemoryManager,
260 _cb_queue: &mut CallbackQueue,
261 ) -> Result<libc::ssize_t, SyscallError> {
262 panic!("Called NetlinkSocket::writev() on a netlink socket");
266 }
267
268 pub fn sendmsg(
269 socket: &Arc<AtomicRefCell<Self>>,
270 args: SendmsgArgs,
271 mem: &mut MemoryManager,
272 _net_ns: &NetworkNamespace,
273 _rng: impl rand::Rng,
274 cb_queue: &mut CallbackQueue,
275 ) -> Result<libc::ssize_t, SyscallError> {
276 let socket_ref = &mut *socket.borrow_mut();
277 socket_ref
278 .protocol_state
279 .sendmsg(&mut socket_ref.common, socket, args, mem, cb_queue)
280 }
281
282 pub fn recvmsg(
283 socket: &Arc<AtomicRefCell<Self>>,
284 args: RecvmsgArgs,
285 mem: &mut MemoryManager,
286 cb_queue: &mut CallbackQueue,
287 ) -> Result<RecvmsgReturn, SyscallError> {
288 let socket_ref = &mut *socket.borrow_mut();
289 socket_ref
290 .protocol_state
291 .recvmsg(&mut socket_ref.common, socket, args, mem, cb_queue)
292 }
293
294 pub fn listen(
295 _socket: &Arc<AtomicRefCell<Self>>,
296 _backlog: i32,
297 _net_ns: &NetworkNamespace,
298 _rng: impl rand::Rng,
299 _cb_queue: &mut CallbackQueue,
300 ) -> Result<(), Errno> {
301 warn_once_then_debug!("We do not yet handle listen request on netlink sockets");
302 Err(Errno::EINVAL)
303 }
304
305 pub fn connect(
306 _socket: &Arc<AtomicRefCell<Self>>,
307 _addr: &SockaddrStorage,
308 _net_ns: &NetworkNamespace,
309 _rng: impl rand::Rng,
310 _cb_queue: &mut CallbackQueue,
311 ) -> Result<(), SyscallError> {
312 warn_once_then_debug!("We do not yet handle connect request on netlink sockets");
313 Err(Errno::EINVAL.into())
314 }
315
316 pub fn accept(
317 &mut self,
318 _net_ns: &NetworkNamespace,
319 _rng: impl rand::Rng,
320 _cb_queue: &mut CallbackQueue,
321 ) -> Result<OpenFile, SyscallError> {
322 warn_once_then_debug!("We do not yet handle accept request on netlink sockets");
323 Err(Errno::EINVAL.into())
324 }
325
326 pub fn ioctl(
327 &mut self,
328 request: IoctlRequest,
329 _arg_ptr: ForeignPtr<()>,
330 _memory_manager: &mut MemoryManager,
331 ) -> SyscallResult {
332 warn_once_then_debug!("We do not yet handle ioctl request {request:?} on netlink sockets");
333 Err(Errno::EINVAL.into())
334 }
335
336 pub fn stat(&self) -> Result<linux_api::stat::stat, SyscallError> {
337 warn_once_then_debug!("We do not yet handle stat calls on netlink sockets");
338 Err(Errno::EINVAL.into())
339 }
340
341 pub fn add_listener(
342 &mut self,
343 monitoring_state: FileState,
344 monitoring_signals: FileSignals,
345 filter: StateListenerFilter,
346 notify_fn: impl Fn(FileState, FileState, FileSignals, &mut CallbackQueue)
347 + Send
348 + Sync
349 + 'static,
350 ) -> StateListenHandle {
351 self.common.event_source.add_listener(
352 monitoring_state,
353 monitoring_signals,
354 filter,
355 notify_fn,
356 )
357 }
358
359 pub fn add_legacy_listener(&mut self, ptr: HostTreePointer<c::StatusListener>) {
360 self.common.event_source.add_legacy_listener(ptr);
361 }
362
363 pub fn remove_legacy_listener(&mut self, ptr: *mut c::StatusListener) {
364 self.common.event_source.remove_legacy_listener(ptr);
365 }
366
367 pub fn state(&self) -> FileState {
368 self.common.state
369 }
370}
371
/// Protocol state of an open netlink socket.
struct InitialState {
    /// Address recorded by `bind()`, if any.
    bound_addr: Option<NetlinkAddr>,
    /// This socket's reader registration on the shared buffer; removed again in `close()`.
    reader_handle: ReaderHandle,
    /// Handle for the buffer listener registered in `ProtocolState::new`. It is never read;
    /// presumably it is kept so the listener stays registered for the life of this state.
    _buffer_handle: BufferHandle,
}
/// Protocol state of a netlink socket after `close()`; it carries no data.
struct ClosedState {}
/// The protocol state machine of a netlink socket (`Initial` -> `Closed`).
///
/// Each variant wraps its state in an `Option` so the state can be moved out by value with
/// `Option::take()` during a transition (see `ProtocolState::close`). All other methods
/// `unwrap()` it and so expect it to be `Some`.
enum ProtocolState {
    Initial(Option<InitialState>),
    Closed(Option<ClosedState>),
}
386
/// Generates `impl From<$state> for $parent_enum`, wrapping a state value as
/// `$parent_enum::$variant(Some(value))`.
macro_rules! state_upcast {
    ($state:ty, $parent_enum:ident::$variant:ident) => {
        impl From<$state> for $parent_enum {
            fn from(state: $state) -> Self {
                $parent_enum::$variant(Some(state))
            }
        }
    };
}
397
398state_upcast!(InitialState, ProtocolState::Initial);
400state_upcast!(ClosedState, ProtocolState::Closed);
401
402impl ProtocolState {
403 fn new(common: &mut NetlinkSocketCommon, socket: &Weak<AtomicRefCell<NetlinkSocket>>) -> Self {
404 let mut cb_queue = CallbackQueue::new();
406
407 let reader_handle = common.buffer.borrow_mut().add_reader(&mut cb_queue);
409
410 let weak = Weak::clone(socket);
411 let buffer_handle = common.buffer.borrow_mut().add_listener(
412 BufferState::READABLE,
413 BufferSignals::BUFFER_GREW,
414 move |_, signals, cb_queue| {
415 if let Some(socket) = weak.upgrade() {
416 let signals = if signals.contains(BufferSignals::BUFFER_GREW) {
417 FileSignals::READ_BUFFER_GREW
418 } else {
419 FileSignals::empty()
420 };
421
422 socket.borrow_mut().refresh_file_state(signals, cb_queue);
423 }
424 },
425 );
426
427 ProtocolState::Initial(Some(InitialState {
428 bound_addr: None,
429 reader_handle,
430 _buffer_handle: buffer_handle,
431 }))
432 }
433
434 fn bound_address(&self) -> Result<Option<NetlinkAddr>, Errno> {
435 match self {
436 Self::Initial(x) => x.as_ref().unwrap().bound_address(),
437 Self::Closed(x) => x.as_ref().unwrap().bound_address(),
438 }
439 }
440
441 fn refresh_file_state(
442 &self,
443 common: &mut NetlinkSocketCommon,
444 signals: FileSignals,
445 cb_queue: &mut CallbackQueue,
446 ) {
447 match self {
448 Self::Initial(x) => x
449 .as_ref()
450 .unwrap()
451 .refresh_file_state(common, signals, cb_queue),
452 Self::Closed(x) => x
453 .as_ref()
454 .unwrap()
455 .refresh_file_state(common, signals, cb_queue),
456 }
457 }
458
459 fn close(
460 &mut self,
461 common: &mut NetlinkSocketCommon,
462 cb_queue: &mut CallbackQueue,
463 ) -> Result<(), SyscallError> {
464 let (new_state, rv) = match self {
465 Self::Initial(x) => x.take().unwrap().close(common, cb_queue),
466 Self::Closed(x) => x.take().unwrap().close(common, cb_queue),
467 };
468
469 *self = new_state;
470 rv
471 }
472
473 fn bind(
474 &mut self,
475 common: &mut NetlinkSocketCommon,
476 socket: &Arc<AtomicRefCell<NetlinkSocket>>,
477 addr: Option<&SockaddrStorage>,
478 rng: impl rand::Rng,
479 ) -> Result<(), SyscallError> {
480 match self {
481 Self::Initial(x) => x.as_mut().unwrap().bind(common, socket, addr, rng),
482 Self::Closed(x) => x.as_mut().unwrap().bind(common, socket, addr, rng),
483 }
484 }
485
486 fn sendmsg(
487 &mut self,
488 common: &mut NetlinkSocketCommon,
489 socket: &Arc<AtomicRefCell<NetlinkSocket>>,
490 args: SendmsgArgs,
491 mem: &mut MemoryManager,
492 cb_queue: &mut CallbackQueue,
493 ) -> Result<libc::ssize_t, SyscallError> {
494 match self {
495 Self::Initial(x) => x
496 .as_mut()
497 .unwrap()
498 .sendmsg(common, socket, args, mem, cb_queue),
499 Self::Closed(x) => x
500 .as_mut()
501 .unwrap()
502 .sendmsg(common, socket, args, mem, cb_queue),
503 }
504 }
505
506 fn recvmsg(
507 &mut self,
508 common: &mut NetlinkSocketCommon,
509 socket: &Arc<AtomicRefCell<NetlinkSocket>>,
510 args: RecvmsgArgs,
511 mem: &mut MemoryManager,
512 cb_queue: &mut CallbackQueue,
513 ) -> Result<RecvmsgReturn, SyscallError> {
514 match self {
515 Self::Initial(x) => x
516 .as_mut()
517 .unwrap()
518 .recvmsg(common, socket, args, mem, cb_queue),
519 Self::Closed(x) => x
520 .as_mut()
521 .unwrap()
522 .recvmsg(common, socket, args, mem, cb_queue),
523 }
524 }
525}
526
527impl InitialState {
528 fn bound_address(&self) -> Result<Option<NetlinkAddr>, Errno> {
529 Ok(self.bound_addr)
530 }
531
532 fn refresh_file_state(
533 &self,
534 common: &mut NetlinkSocketCommon,
535 signals: FileSignals,
536 cb_queue: &mut CallbackQueue,
537 ) {
538 let mut new_state = FileState::ACTIVE;
539
540 {
541 let buffer = common.buffer.borrow();
542
543 new_state.set(FileState::READABLE, buffer.has_data());
544 new_state.set(FileState::WRITABLE, common.sent_len < common.send_limit);
545 }
546
547 common.update_state(
548 FileState::all(),
549 new_state,
550 signals,
551 cb_queue,
552 );
553 }
554
555 fn close(
556 self,
557 common: &mut NetlinkSocketCommon,
558 cb_queue: &mut CallbackQueue,
559 ) -> (ProtocolState, Result<(), SyscallError>) {
560 common
562 .buffer
563 .borrow_mut()
564 .remove_reader(self.reader_handle, cb_queue);
565
566 let new_state = ClosedState {};
567 new_state.refresh_file_state(common, FileSignals::empty(), cb_queue);
568 (new_state.into(), Ok(()))
569 }
570
571 fn bind(
572 &mut self,
573 _common: &mut NetlinkSocketCommon,
574 _socket: &Arc<AtomicRefCell<NetlinkSocket>>,
575 addr: Option<&SockaddrStorage>,
576 _rng: impl rand::Rng,
577 ) -> Result<(), SyscallError> {
578 if self.bound_addr.is_some() {
580 return Err(Errno::EINVAL.into());
581 }
582 if addr.is_none() {
584 return Err(Errno::EFAULT.into());
585 }
586
587 let Some(addr) = addr.and_then(|x| x.as_netlink()) else {
589 log::warn!("Attempted to bind netlink socket to non-netlink address {addr:?}");
590 return Err(Errno::EINVAL.into());
591 };
592
593 self.bound_addr = Some(*addr);
597
598 if (addr.groups() & !(RTMGRP_IPV4_IFADDR | RTMGRP_IPV6_IFADDR)) != 0 {
601 log::warn!(
602 "Attempted to bind netlink socket to an address with unsupported groups {}",
603 addr.groups()
604 );
605 return Err(Errno::EINVAL.into());
606 }
607
608 Ok(())
609 }
610
611 fn sendmsg(
612 &mut self,
613 common: &mut NetlinkSocketCommon,
614 socket: &Arc<AtomicRefCell<NetlinkSocket>>,
615 args: SendmsgArgs,
616 mem: &mut MemoryManager,
617 cb_queue: &mut CallbackQueue,
618 ) -> Result<libc::ssize_t, SyscallError> {
619 if !args.control_ptr.ptr().is_null() {
620 log::debug!("Netlink sockets don't yet support control data for sendmsg()");
621 return Err(Errno::EINVAL.into());
622 }
623
624 if let Some(addr) = args.addr {
626 let Some(addr) = addr.as_netlink() else {
628 log::warn!("Attempted to send to non-netlink address {:?}", args.addr);
629 return Err(Errno::EINVAL.into());
630 };
631 if addr.pid() != 0 {
633 log::warn!("Attempted to send to non-kernel netlink address {addr:?}");
634 return Err(Errno::EINVAL.into());
635 }
636 if addr.groups() != 0 {
638 log::warn!("Attempted to send to netlink groups {addr:?}");
639 return Err(Errno::EINVAL.into());
640 }
641 }
642
643 let rv = common.sendmsg(socket, args.iovs, args.flags, mem, cb_queue)?;
644
645 self.refresh_file_state(common, FileSignals::empty(), cb_queue);
646
647 Ok(rv.try_into().unwrap())
648 }
649
650 fn recvmsg(
651 &mut self,
652 common: &mut NetlinkSocketCommon,
653 socket: &Arc<AtomicRefCell<NetlinkSocket>>,
654 args: RecvmsgArgs,
655 mem: &mut MemoryManager,
656 cb_queue: &mut CallbackQueue,
657 ) -> Result<RecvmsgReturn, SyscallError> {
658 if !args.control_ptr.ptr().is_null() {
659 log::debug!("Netlink sockets don't yet support control data for recvmsg()");
660 return Err(Errno::EINVAL.into());
661 }
662 let Some(flags) = MsgFlags::from_bits(args.flags) else {
663 warn_once_then_debug!("Unrecognized recv flags: {:#b}", args.flags);
664 return Err(Errno::EINVAL.into());
665 };
666
667 let mut packet_buffer = Vec::new();
668 let (_rv, _num_removed_from_buf) =
669 common.recvmsg(socket, &mut packet_buffer, flags, mem, cb_queue)?;
670 self.refresh_file_state(common, FileSignals::empty(), cb_queue);
671
672 let mut writer = IoVecWriter::new(args.iovs, mem);
673
674 let src_addr = SockaddrStorage::from_netlink(&NetlinkAddr::new(0, 0));
676
677 if packet_buffer.len() < std::mem::size_of::<nlmsghdr>() {
678 log::warn!("The processed packet is too short");
679 return Err(Errno::EINVAL.into());
680 }
681
682 let buffer = {
683 let nlmsg_type = &packet_buffer[memoffset::span_of!(nlmsghdr, nlmsg_type)];
684 let nlmsg_type = u16::from_ne_bytes(nlmsg_type.try_into().unwrap());
685
686 match nlmsg_type {
687 RTM_GETLINK => {
688 let nlmsghdr_len = std::mem::size_of::<nlmsghdr>();
689 let ifinfomsg_len = std::mem::size_of::<ifinfomsg>();
690 let header_len = nlmsghdr_len + ifinfomsg_len;
691
692 if (nlmsghdr_len..header_len).contains(&packet_buffer.len()) {
698 log::debug!(
699 "Padding the RTM_GETLINK with zeroes to meet the minimum length"
700 );
701 packet_buffer.resize(header_len, 0);
702 packet_buffer[memoffset::span_of!(nlmsghdr, nlmsg_len)]
703 .copy_from_slice(&(header_len as u32).to_ne_bytes()[..]);
704 }
705 self.handle_ifinfomsg(common, &packet_buffer[..])
706 }
707 RTM_GETADDR => {
708 let nlmsghdr_len = std::mem::size_of::<nlmsghdr>();
709 let ifaddrmsg_len = std::mem::size_of::<ifaddrmsg>();
710 let header_len = nlmsghdr_len + ifaddrmsg_len;
711
712 if (nlmsghdr_len..header_len).contains(&packet_buffer.len()) {
718 log::debug!(
719 "Padding the RTM_GETADDR with zeroes to meet the minimum length"
720 );
721 packet_buffer.resize(header_len, 0);
722 packet_buffer[memoffset::span_of!(nlmsghdr, nlmsg_len)]
723 .copy_from_slice(&(header_len as u32).to_ne_bytes()[..]);
724 }
725 self.handle_ifaddrmsg(common, &packet_buffer[..])
726 }
727 _ => {
728 warn_once_then_debug!(
729 "Found unsupported nlmsg_type: {nlmsg_type} (only RTM_GETLINK
730 and RTM_GETADDR are supported)"
731 );
732 self.handle_error(&packet_buffer[..])
733 }
734 }
735 };
736
737 let mut total_copied = 0;
739 let mut buf = buffer.as_slice();
740 while !buf.is_empty() {
741 match writer.write(buf) {
742 Ok(0) => break,
743 Ok(n) => {
744 buf = &buf[n..];
745 total_copied += n;
746 }
747 Err(ref e) if e.kind() == ErrorKind::Interrupted => continue,
748 Err(e) => return Err(e.into()),
749 }
750 }
751
752 let return_val = if flags.contains(MsgFlags::MSG_TRUNC) {
753 buffer.len()
754 } else {
755 total_copied
756 };
757
758 Ok(RecvmsgReturn {
759 return_val: return_val.try_into().unwrap(),
760 addr: Some(src_addr),
761 msg_flags: 0,
762 control_len: 0,
763 })
764 }
765
766 fn handle_error(&self, bytes: &[u8]) -> Vec<u8> {
767 let nlmsg_seq = match bytes.get(memoffset::span_of!(nlmsghdr, nlmsg_seq)) {
769 Some(x) => u32::from_ne_bytes(x.try_into().unwrap()),
770 None => 0,
771 };
772
773 let msg = NlmsghdrBuilder::default()
775 .nl_type(Nlmsg::Error)
776 .nl_flags(NlmF::empty())
777 .nl_seq(nlmsg_seq)
778 .nl_payload(NlPayload::<Nlmsg, ()>::Empty)
779 .build()
780 .expect("NlmsghdrBuilder missing a required field");
781
782 let mut buffer = Cursor::new(Vec::new());
783 msg.to_bytes(&mut buffer).unwrap();
784 buffer.into_inner()
785 }
786
787 fn handle_ifaddrmsg(&self, common: &mut NetlinkSocketCommon, bytes: &[u8]) -> Vec<u8> {
788 let Ok(nlmsg) = Nlmsghdr::<Rtm, Ifaddrmsg>::from_bytes(&mut Cursor::new(bytes)) else {
789 log::warn!("Failed to deserialize the message");
790 return self.handle_error(bytes);
791 };
792
793 let Some(ifaddrmsg) = nlmsg.get_payload() else {
794 log::warn!("Failed to find the payload");
795 return self.handle_error(bytes);
796 };
797
798 if *ifaddrmsg.ifa_family() != RtAddrFamily::Unspecified
800 && *ifaddrmsg.ifa_family() != RtAddrFamily::Inet
801 {
802 log::warn!("Unsupported ifa_family (only AF_UNSPEC and AF_INET are supported)");
803 return self.handle_error(bytes);
804 }
805
806 if *ifaddrmsg.ifa_prefixlen() != 0
808 || !ifaddrmsg.ifa_flags().is_empty()
809 || *ifaddrmsg.ifa_index() != 0
810 || *ifaddrmsg.ifa_scope() != RtScope::Universe
811 {
812 log::warn!(
813 "Unsupported ifa_prefixlen, ifa_flags, ifa_scope, or ifa_index (they have to be 0)",
814 );
815 return self.handle_error(bytes);
816 }
817
818 let mut buffer = Cursor::new(Vec::new());
819 for interface in &common.interfaces {
821 let address = interface.address.octets();
822 let broadcast = Ipv4Addr::from(
823 0xffff_ffff_u32
824 .checked_shr(u32::from(interface.prefix_len))
825 .unwrap_or(0)
826 | u32::from(interface.address),
827 )
828 .octets();
829 let mut label = Vec::from(interface.label.as_bytes());
830 label.push(0); let attrs = [
834 RtattrBuilder::default()
838 .rta_type(Ifa::Address)
839 .rta_payload(Buffer::from(&address[..]))
840 .build()
841 .unwrap(),
842 RtattrBuilder::default()
843 .rta_type(Ifa::Local)
844 .rta_payload(Buffer::from(&address[..]))
845 .build()
846 .unwrap(),
847 RtattrBuilder::default()
848 .rta_type(Ifa::Broadcast)
849 .rta_payload(Buffer::from(&broadcast[..]))
850 .build()
851 .unwrap(),
852 RtattrBuilder::default()
853 .rta_type(Ifa::Label)
854 .rta_payload(Buffer::from(label))
855 .build()
856 .unwrap(),
857 ];
858 let ifaddrmsg = IfaddrmsgBuilder::default()
859 .ifa_family(RtAddrFamily::Inet)
860 .ifa_prefixlen(interface.prefix_len)
861 .ifa_flags(IfaF::PERMANENT)
863 .ifa_scope(interface.scope)
864 .ifa_index(interface.index)
865 .rtattrs(RtBuffer::from_iter(attrs))
866 .build()
867 .expect("IfaddrmsgBuilder missing a required field");
868 let nlmsg = NlmsghdrBuilder::default()
869 .nl_type(Rtm::Newaddr)
870 .nl_flags(NlmF::MULTI)
872 .nl_seq(*nlmsg.nl_seq())
874 .nl_payload(NlPayload::Payload(ifaddrmsg))
875 .build()
876 .expect("NlmsghdrBuilder missing a required field");
877 nlmsg.to_bytes(&mut buffer).unwrap();
878 }
879 let done_msg = NlmsghdrBuilder::default()
881 .nl_type(Nlmsg::Done)
882 .nl_flags(NlmF::MULTI)
883 .nl_seq(*nlmsg.nl_seq())
885 .nl_payload(NlPayload::Payload(0u32))
891 .build()
892 .expect("NlmsghdrBuilder missing a required field");
893 done_msg.to_bytes(&mut buffer).unwrap();
894
895 buffer.into_inner()
896 }
897
898 fn handle_ifinfomsg(&self, common: &mut NetlinkSocketCommon, bytes: &[u8]) -> Vec<u8> {
899 let Ok(nlmsg) = Nlmsghdr::<Rtm, Ifinfomsg>::from_bytes(&mut Cursor::new(bytes)) else {
900 log::warn!("Failed to deserialize the message");
901 return self.handle_error(bytes);
902 };
903
904 let Some(ifinfomsg) = nlmsg.get_payload() else {
905 log::warn!("Failed to find the payload");
906 return self.handle_error(bytes);
907 };
908
909 if *ifinfomsg.ifi_family() != RtAddrFamily::Unspecified
911 && *ifinfomsg.ifi_family() != RtAddrFamily::Inet
912 {
913 warn_once_then_debug!(
914 "Unsupported ifi_family (only AF_UNSPEC and AF_INET are supported)"
915 );
916 return self.handle_error(bytes);
917 }
918
919 if *ifinfomsg.ifi_type() != 0.into()
921 || *ifinfomsg.ifi_index() != 0
922 || !ifinfomsg.ifi_flags().is_empty()
923 {
924 warn_once_then_debug!(
925 "Unsupported ifi_type, ifi_index, or ifi_flags (they have to be 0)"
926 );
927 return self.handle_error(bytes);
928 }
929
930 let mut buffer = Cursor::new(Vec::new());
934 for interface in &common.interfaces {
936 let mut label = Vec::from(interface.label.as_bytes());
937 label.push(0); let attrs = [
941 RtattrBuilder::default()
942 .rta_type(Ifla::Ifname)
943 .rta_payload(Buffer::from(label))
944 .build()
945 .unwrap(),
946 RtattrBuilder::default()
950 .rta_type(Ifla::Txqlen)
951 .rta_payload(Buffer::from(&u32::to_le_bytes(1000)[..]))
952 .build()
953 .unwrap(),
954 RtattrBuilder::default()
955 .rta_type(Ifla::Mtu)
956 .rta_payload(Buffer::from(&u32::to_le_bytes(interface.mtu)[..]))
957 .build()
958 .unwrap(),
959 ];
961 let flags = if interface.if_type == Arphrd::Loopback {
962 Iff::UP | Iff::LOOPBACK | Iff::RUNNING
963 } else {
964 Iff::UP | Iff::BROADCAST | Iff::RUNNING | Iff::MULTICAST
966 };
967 let interface_index = interface
968 .index
969 .try_into()
970 .expect("interface index too large");
971
972 let ifinfomsg = IfinfomsgBuilder::default()
973 .ifi_family(RtAddrFamily::Inet)
974 .ifi_type(interface.if_type)
975 .ifi_index(interface_index)
976 .ifi_flags(flags)
977 .ifi_change(Iff::from_bits_retain(0xffffffff))
979 .rtattrs(RtBuffer::from_iter(attrs))
980 .build()
981 .expect("IfinfomsgBuilder missing a required field");
982 let nlmsg = NlmsghdrBuilder::default()
983 .nl_type(Rtm::Newlink)
984 .nl_flags(NlmF::MULTI)
986 .nl_seq(*nlmsg.nl_seq())
988 .nl_payload(NlPayload::Payload(ifinfomsg))
989 .build()
990 .expect("NlmsghdrBuilder missing a required field");
991 nlmsg.to_bytes(&mut buffer).unwrap();
992 }
993 let done_msg = NlmsghdrBuilder::default()
995 .nl_type(Nlmsg::Done)
996 .nl_flags(NlmF::MULTI)
997 .nl_seq(*nlmsg.nl_seq())
999 .nl_payload(NlPayload::Payload(0u32))
1005 .build()
1006 .expect("NlmsghdrBuilder missing a required field");
1007 done_msg.to_bytes(&mut buffer).unwrap();
1008
1009 buffer.into_inner()
1010 }
1011}
1012
1013impl ClosedState {
1014 fn bound_address(&self) -> Result<Option<NetlinkAddr>, Errno> {
1015 Ok(None)
1016 }
1017
1018 fn refresh_file_state(
1019 &self,
1020 common: &mut NetlinkSocketCommon,
1021 signals: FileSignals,
1022 cb_queue: &mut CallbackQueue,
1023 ) {
1024 common.update_state(
1025 FileState::all(),
1026 FileState::CLOSED,
1027 signals,
1028 cb_queue,
1029 );
1030 }
1031
1032 fn close(
1033 self,
1034 _common: &mut NetlinkSocketCommon,
1035 _cb_queue: &mut CallbackQueue,
1036 ) -> (ProtocolState, Result<(), SyscallError>) {
1037 panic!("Trying to close an already closed socket");
1039 }
1040
1041 fn bind(
1042 &mut self,
1043 _common: &mut NetlinkSocketCommon,
1044 _socket: &Arc<AtomicRefCell<NetlinkSocket>>,
1045 _addr: Option<&SockaddrStorage>,
1046 _rng: impl rand::Rng,
1047 ) -> Result<(), SyscallError> {
1048 log::warn!("bind() while in state {}", std::any::type_name::<Self>());
1050 Err(Errno::EOPNOTSUPP.into())
1051 }
1052
1053 fn sendmsg(
1054 &mut self,
1055 _common: &mut NetlinkSocketCommon,
1056 _socket: &Arc<AtomicRefCell<NetlinkSocket>>,
1057 _args: SendmsgArgs,
1058 _mem: &mut MemoryManager,
1059 _cb_queue: &mut CallbackQueue,
1060 ) -> Result<libc::ssize_t, SyscallError> {
1061 log::warn!("sendmsg() while in state {}", std::any::type_name::<Self>());
1063 Err(Errno::EOPNOTSUPP.into())
1064 }
1065
1066 fn recvmsg(
1067 &mut self,
1068 _common: &mut NetlinkSocketCommon,
1069 _socket: &Arc<AtomicRefCell<NetlinkSocket>>,
1070 _args: RecvmsgArgs,
1071 _mem: &mut MemoryManager,
1072 _cb_queue: &mut CallbackQueue,
1073 ) -> Result<RecvmsgReturn, SyscallError> {
1074 log::warn!("recvmsg() while in state {}", std::any::type_name::<Self>());
1076 Err(Errno::EOPNOTSUPP.into())
1077 }
1078}
1079
/// A network interface reported to applications in `RTM_GETLINK` / `RTM_GETADDR` replies.
struct Interface {
    /// IPv4 address of the interface (sent as IFA_ADDRESS and IFA_LOCAL).
    address: Ipv4Addr,
    /// Interface name, e.g. "lo" or "eth0" (sent NUL-terminated as IFLA_IFNAME / IFA_LABEL).
    label: String,
    /// Prefix length in bits; also used to derive the broadcast address.
    prefix_len: u8,
    /// Hardware type reported in `ifi_type`; loopback interfaces also get `IFF_LOOPBACK`.
    if_type: Arphrd,
    /// MTU reported as IFLA_MTU.
    mtu: u32,
    /// Address scope reported in `ifa_scope`.
    scope: RtScope,
    /// Interface index reported in `ifi_index` / `ifa_index`.
    index: libc::c_uint,
}
1090
/// State shared by every protocol state of a netlink socket.
struct NetlinkSocketCommon {
    /// Queue of request packets written by `sendmsg()` and read back by `recvmsg()`.
    buffer: Arc<AtomicRefCell<SharedBuf>>,
    /// Maximum number of bytes that may be accounted as queued (adjusted via `SO_SNDBUF`).
    send_limit: u64,
    /// Bytes accounted as queued, increased by each successful `sendmsg()`. The socket is
    /// writable while this is below `send_limit`.
    sent_len: u64,
    /// Listeners notified of file-state changes and signals.
    event_source: StateEventSource,
    /// Current file state (readable, writable, closed, ...).
    state: FileState,
    /// File status flags (e.g. `O_NONBLOCK`).
    status: FileStatus,
    /// Whether an `OpenFile` currently refers to this socket.
    has_open_file: bool,
    /// Interfaces reported in `RTM_GETLINK` / `RTM_GETADDR` replies.
    interfaces: Vec<Interface>,
}
1107
1108impl NetlinkSocketCommon {
1109 pub fn supports_sa_restart(&self) -> bool {
1110 true
1111 }
1112
1113 pub fn sendmsg(
1114 &mut self,
1115 socket: &Arc<AtomicRefCell<NetlinkSocket>>,
1116 iovs: &[IoVec],
1117 flags: libc::c_int,
1118 mem: &mut MemoryManager,
1119 cb_queue: &mut CallbackQueue,
1120 ) -> Result<usize, SyscallError> {
1121 let supported_flags = MsgFlags::MSG_DONTWAIT | MsgFlags::MSG_NOSIGNAL | MsgFlags::MSG_TRUNC;
1124
1125 let Some(mut flags) = MsgFlags::from_bits(flags) else {
1128 warn_once_then_debug!("Unrecognized send flags: {:#b}", flags);
1129 return Err(Errno::EINVAL.into());
1130 };
1131 if flags.intersects(!supported_flags) {
1132 warn_once_then_debug!("Unsupported send flags: {:?}", flags);
1133 return Err(Errno::EINVAL.into());
1134 }
1135
1136 if self.status.contains(FileStatus::NONBLOCK) {
1137 flags.insert(MsgFlags::MSG_DONTWAIT);
1138 }
1139
1140 let result = (|| {
1142 let len = iovs.iter().map(|x| x.len).sum::<libc::size_t>();
1143
1144 let space_available = self
1147 .send_limit
1148 .saturating_sub(self.sent_len)
1149 .try_into()
1150 .unwrap();
1151
1152 if space_available == 0 {
1153 return Err(Errno::EAGAIN);
1154 }
1155
1156 if len > space_available {
1157 if len <= self.send_limit.try_into().unwrap() {
1158 return Err(Errno::EAGAIN);
1160 } else {
1161 return Err(Errno::EMSGSIZE);
1163 }
1164 }
1165
1166 let reader = IoVecReader::new(iovs, mem);
1167 let reader = reader.take(len.try_into().unwrap());
1168
1169 self.buffer
1172 .borrow_mut()
1173 .write_packet(reader, len, cb_queue)
1174 .map_err(|e| Errno::try_from(e).unwrap())?;
1175
1176 self.sent_len += u64::try_from(len).unwrap();
1178 Ok(len)
1179 })();
1180
1181 if result.as_ref().err() == Some(&Errno::EWOULDBLOCK)
1183 && !flags.contains(MsgFlags::MSG_DONTWAIT)
1184 {
1185 return Err(SyscallError::new_blocked_on_file(
1186 File::Socket(Socket::Netlink(socket.clone())),
1187 FileState::WRITABLE,
1188 self.supports_sa_restart(),
1189 ));
1190 }
1191
1192 Ok(result?)
1193 }
1194
1195 pub fn recvmsg<W: Write>(
1196 &mut self,
1197 socket: &Arc<AtomicRefCell<NetlinkSocket>>,
1198 dst: W,
1199 mut flags: MsgFlags,
1200 _mem: &mut MemoryManager,
1201 cb_queue: &mut CallbackQueue,
1202 ) -> Result<(usize, usize), SyscallError> {
1203 let supported_flags = MsgFlags::MSG_DONTWAIT | MsgFlags::MSG_PEEK | MsgFlags::MSG_TRUNC;
1204
1205 if flags.intersects(!supported_flags) {
1208 warn_once_then_debug!("Unsupported recv flags: {:?}", flags);
1209 return Err(Errno::EINVAL.into());
1210 }
1211
1212 if self.status.contains(FileStatus::NONBLOCK) {
1213 flags.insert(MsgFlags::MSG_DONTWAIT);
1214 }
1215
1216 let result = (|| {
1218 let mut buffer = self.buffer.borrow_mut();
1219
1220 if !buffer.has_data() {
1222 return Err(Errno::EWOULDBLOCK);
1223 }
1224
1225 let (num_copied, num_removed_from_buf) = if flags.contains(MsgFlags::MSG_PEEK) {
1226 buffer.peek(dst).map_err(|e| Errno::try_from(e).unwrap())?
1227 } else {
1228 buffer
1229 .read(dst, cb_queue)
1230 .map_err(|e| Errno::try_from(e).unwrap())?
1231 };
1232
1233 if flags.contains(MsgFlags::MSG_TRUNC) {
1234 Ok((num_removed_from_buf, num_removed_from_buf))
1236 } else {
1237 Ok((num_copied, num_removed_from_buf))
1238 }
1239 })();
1240
1241 if result.as_ref().err() == Some(&Errno::EWOULDBLOCK)
1243 && !flags.contains(MsgFlags::MSG_DONTWAIT)
1244 {
1245 return Err(SyscallError::new_blocked_on_file(
1246 File::Socket(Socket::Netlink(socket.clone())),
1247 FileState::READABLE,
1248 self.supports_sa_restart(),
1249 ));
1250 }
1251
1252 Ok(result?)
1253 }
1254
1255 fn update_state(
1256 &mut self,
1257 mask: FileState,
1258 state: FileState,
1259 signals: FileSignals,
1260 cb_queue: &mut CallbackQueue,
1261 ) {
1262 let old_state = self.state;
1263
1264 self.state.remove(mask);
1266 self.state.insert(state & mask);
1267
1268 self.handle_state_change(old_state, signals, cb_queue);
1269 }
1270
1271 fn handle_state_change(
1272 &mut self,
1273 old_state: FileState,
1274 signals: FileSignals,
1275 cb_queue: &mut CallbackQueue,
1276 ) {
1277 let states_changed = self.state ^ old_state;
1278
1279 if states_changed.is_empty() && signals.is_empty() {
1281 return;
1282 }
1283
1284 self.event_source
1285 .notify_listeners(self.state, states_changed, signals, cb_queue);
1286 }
1287}
1288
1289#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
1290pub enum NetlinkSocketType {
1291 Dgram,
1292 Raw,
1293}
1294
1295impl TryFrom<libc::c_int> for NetlinkSocketType {
1296 type Error = NetlinkSocketTypeConversionError;
1297 fn try_from(val: libc::c_int) -> Result<Self, Self::Error> {
1298 match val {
1299 libc::SOCK_DGRAM => Ok(Self::Dgram),
1300 libc::SOCK_RAW => Ok(Self::Raw),
1301 x => Err(NetlinkSocketTypeConversionError(x)),
1302 }
1303 }
1304}
1305
1306#[derive(Copy, Clone, Debug)]
1307pub struct NetlinkSocketTypeConversionError(libc::c_int);
1308
1309impl std::error::Error for NetlinkSocketTypeConversionError {}
1310
1311impl std::fmt::Display for NetlinkSocketTypeConversionError {
1312 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
1313 write!(
1314 f,
1315 "Invalid socket type {}; netlink sockets only support SOCK_DGRAM and SOCK_RAW",
1316 self.0
1317 )
1318 }
1319}
1320
1321#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
1322pub enum NetlinkFamily {
1323 Route,
1324}
1325
1326impl TryFrom<libc::c_int> for NetlinkFamily {
1327 type Error = NetlinkFamilyConversionError;
1328 fn try_from(val: libc::c_int) -> Result<Self, Self::Error> {
1329 match val {
1330 libc::NETLINK_ROUTE => Ok(Self::Route),
1331 x => Err(NetlinkFamilyConversionError(x)),
1332 }
1333 }
1334}
1335
1336#[derive(Copy, Clone, Debug)]
1337pub struct NetlinkFamilyConversionError(libc::c_int);
1338
1339impl std::error::Error for NetlinkFamilyConversionError {}
1340
1341impl std::fmt::Display for NetlinkFamilyConversionError {
1342 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
1343 write!(
1344 f,
1345 "Invalid netlink family {}; netlink families only support NETLINK_ROUTE",
1346 self.0
1347 )
1348 }
1349}