1use std::io::{Cursor, ErrorKind, Read, Write};
2use std::net::Ipv4Addr;
3use std::sync::{Arc, Weak};
4
5use atomic_refcell::AtomicRefCell;
6use linux_api::errno::Errno;
7use linux_api::ioctls::IoctlRequest;
8use linux_api::netlink::{ifaddrmsg, ifinfomsg, nlmsghdr};
9use linux_api::rtnetlink::{RTM_GETADDR, RTM_GETLINK, RTMGRP_IPV4_IFADDR, RTMGRP_IPV6_IFADDR};
10use linux_api::socket::Shutdown;
11use neli::consts::nl::{NlmF, NlmFFlags, Nlmsg};
12use neli::consts::rtnl::{
13 Arphrd, Ifa, IfaF, IfaFFlags, Iff, IffFlags, Ifla, RtAddrFamily, RtScope, Rtm,
14};
15use neli::nl::{NlPayload, Nlmsghdr};
16use neli::rtnl::{Ifaddrmsg, Ifinfomsg, Rtattr};
17use neli::types::{Buffer, RtBuffer};
18use neli::{FromBytes, ToBytes};
19use nix::sys::socket::{MsgFlags, NetlinkAddr};
20use shadow_shim_helper_rs::syscall_types::ForeignPtr;
21
22use crate::core::worker::Worker;
23use crate::cshadow as c;
24use crate::host::descriptor::listener::{StateEventSource, StateListenHandle, StateListenerFilter};
25use crate::host::descriptor::shared_buf::{
26 BufferHandle, BufferSignals, BufferState, ReaderHandle, SharedBuf,
27};
28use crate::host::descriptor::socket::{RecvmsgArgs, RecvmsgReturn, SendmsgArgs, Socket};
29use crate::host::descriptor::{
30 File, FileMode, FileSignals, FileState, FileStatus, OpenFile, SyscallResult,
31};
32use crate::host::memory_manager::MemoryManager;
33use crate::host::network::namespace::NetworkNamespace;
34use crate::host::syscall::io::{IoVec, IoVecReader, IoVecWriter};
35use crate::host::syscall::types::SyscallError;
36use crate::utility::HostTreePointer;
37use crate::utility::callback_queue::CallbackQueue;
38use crate::utility::sockaddr::SockaddrStorage;
39
/// Send limit (in bytes) that every newly created netlink socket starts with.
const NETLINK_SOCKET_DEFAULT_BUFFER_SIZE: u64 = 212_992;
42
/// A netlink socket, split into data shared by every protocol state and the current
/// protocol state itself.
pub struct NetlinkSocket {
    /// Data and functionality that does not depend on the protocol state.
    common: NetlinkSocketCommon,
    /// The current protocol state (initial or closed) and its state-specific data.
    protocol_state: ProtocolState,
}
49
50impl NetlinkSocket {
51 pub fn new(
52 status: FileStatus,
53 _socket_type: NetlinkSocketType,
54 _family: NetlinkFamily,
55 ) -> Arc<AtomicRefCell<Self>> {
56 Arc::new_cyclic(|weak| {
57 let buffer = SharedBuf::new(usize::MAX);
59 let buffer = Arc::new(AtomicRefCell::new(buffer));
60
61 let default_ip = Worker::with_active_host(|host| host.default_ip()).unwrap();
63 let interfaces = vec![
65 Interface {
66 address: Ipv4Addr::LOCALHOST,
67 label: String::from("lo"),
68 prefix_len: 8,
69 if_type: Arphrd::Loopback,
70 mtu: c::CONFIG_MTU,
71 scope: RtScope::Host,
72 index: 1,
73 },
74 Interface {
75 address: default_ip,
76 label: String::from("eth0"),
77 prefix_len: 24,
78 if_type: Arphrd::Ether,
79 mtu: c::CONFIG_MTU,
80 scope: RtScope::Universe,
81 index: 2,
82 },
83 ];
84
85 let mut common = NetlinkSocketCommon {
86 buffer,
87 send_limit: NETLINK_SOCKET_DEFAULT_BUFFER_SIZE,
88 sent_len: 0,
89 event_source: StateEventSource::new(),
90 state: FileState::ACTIVE,
91 status,
92 has_open_file: false,
93 interfaces,
94 };
95 let protocol_state = ProtocolState::new(&mut common, weak);
96 let mut socket = Self {
97 common,
98 protocol_state,
99 };
100 CallbackQueue::queue_and_run_with_legacy(|cb_queue| {
101 socket.refresh_file_state(FileSignals::empty(), cb_queue)
102 });
103
104 AtomicRefCell::new(socket)
105 })
106 }
107
108 pub fn status(&self) -> FileStatus {
109 self.common.status
110 }
111
112 pub fn set_status(&mut self, status: FileStatus) {
113 self.common.status = status;
114 }
115
116 pub fn mode(&self) -> FileMode {
117 FileMode::READ | FileMode::WRITE
118 }
119
120 pub fn has_open_file(&self) -> bool {
121 self.common.has_open_file
122 }
123
124 pub fn supports_sa_restart(&self) -> bool {
125 self.common.supports_sa_restart()
126 }
127
128 pub fn set_has_open_file(&mut self, val: bool) {
129 self.common.has_open_file = val;
130 }
131
132 pub fn getsockname(&self) -> Result<Option<nix::sys::socket::NetlinkAddr>, Errno> {
133 self.protocol_state.bound_address()
134 }
135
136 pub fn getpeername(&self) -> Result<Option<nix::sys::socket::NetlinkAddr>, Errno> {
137 warn_once_then_debug!(
138 "getpeername() syscall not yet supported for netlink sockets; Returning ENOSYS"
139 );
140 Err(Errno::ENOSYS)
141 }
142
143 pub fn address_family(&self) -> linux_api::socket::AddressFamily {
144 linux_api::socket::AddressFamily::AF_NETLINK
145 }
146
147 pub fn close(&mut self, cb_queue: &mut CallbackQueue) -> Result<(), SyscallError> {
148 self.protocol_state.close(&mut self.common, cb_queue)
149 }
150
151 fn refresh_file_state(&mut self, signals: FileSignals, cb_queue: &mut CallbackQueue) {
152 self.protocol_state
153 .refresh_file_state(&mut self.common, signals, cb_queue)
154 }
155
156 pub fn shutdown(
157 &mut self,
158 _how: Shutdown,
159 _cb_queue: &mut CallbackQueue,
160 ) -> Result<(), SyscallError> {
161 warn_once_then_debug!(
162 "shutdown() syscall not yet supported for netlink sockets; Returning ENOSYS"
163 );
164 Err(Errno::ENOSYS.into())
165 }
166
167 pub fn getsockopt(
168 &mut self,
169 _level: libc::c_int,
170 _optname: libc::c_int,
171 _optval_ptr: ForeignPtr<()>,
172 _optlen: libc::socklen_t,
173 _memory_manager: &mut MemoryManager,
174 _cb_queue: &mut CallbackQueue,
175 ) -> Result<libc::socklen_t, SyscallError> {
176 warn_once_then_debug!(
177 "getsockopt() syscall not yet supported for netlink sockets; Returning ENOSYS"
178 );
179 Err(Errno::ENOSYS.into())
180 }
181
182 pub fn setsockopt(
183 &mut self,
184 level: libc::c_int,
185 optname: libc::c_int,
186 optval_ptr: ForeignPtr<()>,
187 optlen: libc::socklen_t,
188 memory_manager: &MemoryManager,
189 ) -> Result<(), SyscallError> {
190 match (level, optname) {
191 (libc::SOL_SOCKET, libc::SO_SNDBUF) => {
192 type OptType = libc::c_int;
193
194 if usize::try_from(optlen).unwrap() < std::mem::size_of::<OptType>() {
195 return Err(Errno::EINVAL.into());
196 }
197
198 let optval_ptr = optval_ptr.cast::<OptType>();
199 let val: u64 = memory_manager
200 .read(optval_ptr)?
201 .try_into()
202 .or(Err(Errno::EINVAL))?;
203
204 let val = val * 2;
206 let val = std::cmp::max(val, self.common.sent_len);
208 let val = std::cmp::max(val, 4096);
210 let val = std::cmp::min(val, 268435456); self.common.send_limit = val;
213 }
214 (libc::SOL_SOCKET, libc::SO_RCVBUF) => {
215 }
219 _ => {
220 warn_once_then_debug!(
221 "setsockopt called with unsupported level {level} and opt {optname}"
222 );
223 return Err(Errno::ENOPROTOOPT.into());
224 }
225 }
226
227 Ok(())
228 }
229
230 pub fn bind(
231 socket: &Arc<AtomicRefCell<Self>>,
232 addr: Option<&SockaddrStorage>,
233 _net_ns: &NetworkNamespace,
234 rng: impl rand::Rng,
235 ) -> Result<(), SyscallError> {
236 let socket_ref = &mut *socket.borrow_mut();
237 socket_ref
238 .protocol_state
239 .bind(&mut socket_ref.common, socket, addr, rng)
240 }
241
242 pub fn readv(
243 &mut self,
244 _iovs: &[IoVec],
245 _offset: Option<libc::off_t>,
246 _flags: libc::c_int,
247 _mem: &mut MemoryManager,
248 _cb_queue: &mut CallbackQueue,
249 ) -> Result<libc::ssize_t, SyscallError> {
250 panic!("Called NetlinkSocket::readv() on a netlink socket.");
254 }
255
256 pub fn writev(
257 &mut self,
258 _iovs: &[IoVec],
259 _offset: Option<libc::off_t>,
260 _flags: libc::c_int,
261 _mem: &mut MemoryManager,
262 _cb_queue: &mut CallbackQueue,
263 ) -> Result<libc::ssize_t, SyscallError> {
264 panic!("Called NetlinkSocket::writev() on a netlink socket");
268 }
269
270 pub fn sendmsg(
271 socket: &Arc<AtomicRefCell<Self>>,
272 args: SendmsgArgs,
273 mem: &mut MemoryManager,
274 _net_ns: &NetworkNamespace,
275 _rng: impl rand::Rng,
276 cb_queue: &mut CallbackQueue,
277 ) -> Result<libc::ssize_t, SyscallError> {
278 let socket_ref = &mut *socket.borrow_mut();
279 socket_ref
280 .protocol_state
281 .sendmsg(&mut socket_ref.common, socket, args, mem, cb_queue)
282 }
283
284 pub fn recvmsg(
285 socket: &Arc<AtomicRefCell<Self>>,
286 args: RecvmsgArgs,
287 mem: &mut MemoryManager,
288 cb_queue: &mut CallbackQueue,
289 ) -> Result<RecvmsgReturn, SyscallError> {
290 let socket_ref = &mut *socket.borrow_mut();
291 socket_ref
292 .protocol_state
293 .recvmsg(&mut socket_ref.common, socket, args, mem, cb_queue)
294 }
295
296 pub fn listen(
297 _socket: &Arc<AtomicRefCell<Self>>,
298 _backlog: i32,
299 _net_ns: &NetworkNamespace,
300 _rng: impl rand::Rng,
301 _cb_queue: &mut CallbackQueue,
302 ) -> Result<(), Errno> {
303 warn_once_then_debug!("We do not yet handle listen request on netlink sockets");
304 Err(Errno::EINVAL)
305 }
306
307 pub fn connect(
308 _socket: &Arc<AtomicRefCell<Self>>,
309 _addr: &SockaddrStorage,
310 _net_ns: &NetworkNamespace,
311 _rng: impl rand::Rng,
312 _cb_queue: &mut CallbackQueue,
313 ) -> Result<(), SyscallError> {
314 warn_once_then_debug!("We do not yet handle connect request on netlink sockets");
315 Err(Errno::EINVAL.into())
316 }
317
318 pub fn accept(
319 &mut self,
320 _net_ns: &NetworkNamespace,
321 _rng: impl rand::Rng,
322 _cb_queue: &mut CallbackQueue,
323 ) -> Result<OpenFile, SyscallError> {
324 warn_once_then_debug!("We do not yet handle accept request on netlink sockets");
325 Err(Errno::EINVAL.into())
326 }
327
328 pub fn ioctl(
329 &mut self,
330 request: IoctlRequest,
331 _arg_ptr: ForeignPtr<()>,
332 _memory_manager: &mut MemoryManager,
333 ) -> SyscallResult {
334 warn_once_then_debug!("We do not yet handle ioctl request {request:?} on netlink sockets");
335 Err(Errno::EINVAL.into())
336 }
337
338 pub fn stat(&self) -> Result<linux_api::stat::stat, SyscallError> {
339 warn_once_then_debug!("We do not yet handle stat calls on netlink sockets");
340 Err(Errno::EINVAL.into())
341 }
342
343 pub fn add_listener(
344 &mut self,
345 monitoring_state: FileState,
346 monitoring_signals: FileSignals,
347 filter: StateListenerFilter,
348 notify_fn: impl Fn(FileState, FileState, FileSignals, &mut CallbackQueue)
349 + Send
350 + Sync
351 + 'static,
352 ) -> StateListenHandle {
353 self.common.event_source.add_listener(
354 monitoring_state,
355 monitoring_signals,
356 filter,
357 notify_fn,
358 )
359 }
360
361 pub fn add_legacy_listener(&mut self, ptr: HostTreePointer<c::StatusListener>) {
362 self.common.event_source.add_legacy_listener(ptr);
363 }
364
365 pub fn remove_legacy_listener(&mut self, ptr: *mut c::StatusListener) {
366 self.common.event_source.remove_legacy_listener(ptr);
367 }
368
369 pub fn state(&self) -> FileState {
370 self.common.state
371 }
372}
373
/// Protocol state of a socket that has not been closed yet.
struct InitialState {
    /// The address passed to a successful `bind()`, or `None` if the socket is unbound.
    bound_addr: Option<NetlinkAddr>,
    /// Handle returned when this socket was added as a reader of the shared buffer; it is
    /// handed back to the buffer when the socket is closed.
    reader_handle: ReaderHandle,
    /// Handle returned when the buffer listener was registered. It is never read.
    // NOTE(review): presumably dropping this handle unregisters the listener — confirm.
    _buffer_handle: BufferHandle,
}
/// Protocol state of a socket that has been closed; it carries no data.
struct ClosedState {}
/// The current protocol state of a netlink socket.
///
/// Each variant wraps its state object in an `Option` so that the object can be taken by
/// value (see `close()`) while the enum is only borrowed mutably. Every method unwraps the
/// option, so it is expected to be `Some` whenever a method is called.
enum ProtocolState {
    Initial(Option<InitialState>),
    Closed(Option<ClosedState>),
}
388
/// Implements `From<$type> for $parent`, wrapping the value as `$parent::$variant(Some(value))`.
macro_rules! state_upcast {
    ($type:ty, $parent:ident::$variant:ident) => {
        impl From<$type> for $parent {
            fn from(x: $type) -> Self {
                Self::$variant(Some(x))
            }
        }
    };
}

// allow each concrete state to be converted into a `ProtocolState` with `.into()`
state_upcast!(InitialState, ProtocolState::Initial);
state_upcast!(ClosedState, ProtocolState::Closed);
403
404impl ProtocolState {
405 fn new(common: &mut NetlinkSocketCommon, socket: &Weak<AtomicRefCell<NetlinkSocket>>) -> Self {
406 let mut cb_queue = CallbackQueue::new();
408
409 let reader_handle = common.buffer.borrow_mut().add_reader(&mut cb_queue);
411
412 let weak = Weak::clone(socket);
413 let buffer_handle = common.buffer.borrow_mut().add_listener(
414 BufferState::READABLE,
415 BufferSignals::BUFFER_GREW,
416 move |_, signals, cb_queue| {
417 if let Some(socket) = weak.upgrade() {
418 let signals = if signals.contains(BufferSignals::BUFFER_GREW) {
419 FileSignals::READ_BUFFER_GREW
420 } else {
421 FileSignals::empty()
422 };
423
424 socket.borrow_mut().refresh_file_state(signals, cb_queue);
425 }
426 },
427 );
428
429 ProtocolState::Initial(Some(InitialState {
430 bound_addr: None,
431 reader_handle,
432 _buffer_handle: buffer_handle,
433 }))
434 }
435
436 fn bound_address(&self) -> Result<Option<NetlinkAddr>, Errno> {
437 match self {
438 Self::Initial(x) => x.as_ref().unwrap().bound_address(),
439 Self::Closed(x) => x.as_ref().unwrap().bound_address(),
440 }
441 }
442
443 fn refresh_file_state(
444 &self,
445 common: &mut NetlinkSocketCommon,
446 signals: FileSignals,
447 cb_queue: &mut CallbackQueue,
448 ) {
449 match self {
450 Self::Initial(x) => x
451 .as_ref()
452 .unwrap()
453 .refresh_file_state(common, signals, cb_queue),
454 Self::Closed(x) => x
455 .as_ref()
456 .unwrap()
457 .refresh_file_state(common, signals, cb_queue),
458 }
459 }
460
461 fn close(
462 &mut self,
463 common: &mut NetlinkSocketCommon,
464 cb_queue: &mut CallbackQueue,
465 ) -> Result<(), SyscallError> {
466 let (new_state, rv) = match self {
467 Self::Initial(x) => x.take().unwrap().close(common, cb_queue),
468 Self::Closed(x) => x.take().unwrap().close(common, cb_queue),
469 };
470
471 *self = new_state;
472 rv
473 }
474
475 fn bind(
476 &mut self,
477 common: &mut NetlinkSocketCommon,
478 socket: &Arc<AtomicRefCell<NetlinkSocket>>,
479 addr: Option<&SockaddrStorage>,
480 rng: impl rand::Rng,
481 ) -> Result<(), SyscallError> {
482 match self {
483 Self::Initial(x) => x.as_mut().unwrap().bind(common, socket, addr, rng),
484 Self::Closed(x) => x.as_mut().unwrap().bind(common, socket, addr, rng),
485 }
486 }
487
488 fn sendmsg(
489 &mut self,
490 common: &mut NetlinkSocketCommon,
491 socket: &Arc<AtomicRefCell<NetlinkSocket>>,
492 args: SendmsgArgs,
493 mem: &mut MemoryManager,
494 cb_queue: &mut CallbackQueue,
495 ) -> Result<libc::ssize_t, SyscallError> {
496 match self {
497 Self::Initial(x) => x
498 .as_mut()
499 .unwrap()
500 .sendmsg(common, socket, args, mem, cb_queue),
501 Self::Closed(x) => x
502 .as_mut()
503 .unwrap()
504 .sendmsg(common, socket, args, mem, cb_queue),
505 }
506 }
507
508 fn recvmsg(
509 &mut self,
510 common: &mut NetlinkSocketCommon,
511 socket: &Arc<AtomicRefCell<NetlinkSocket>>,
512 args: RecvmsgArgs,
513 mem: &mut MemoryManager,
514 cb_queue: &mut CallbackQueue,
515 ) -> Result<RecvmsgReturn, SyscallError> {
516 match self {
517 Self::Initial(x) => x
518 .as_mut()
519 .unwrap()
520 .recvmsg(common, socket, args, mem, cb_queue),
521 Self::Closed(x) => x
522 .as_mut()
523 .unwrap()
524 .recvmsg(common, socket, args, mem, cb_queue),
525 }
526 }
527}
528
529impl InitialState {
530 fn bound_address(&self) -> Result<Option<NetlinkAddr>, Errno> {
531 Ok(self.bound_addr)
532 }
533
534 fn refresh_file_state(
535 &self,
536 common: &mut NetlinkSocketCommon,
537 signals: FileSignals,
538 cb_queue: &mut CallbackQueue,
539 ) {
540 let mut new_state = FileState::ACTIVE;
541
542 {
543 let buffer = common.buffer.borrow();
544
545 new_state.set(FileState::READABLE, buffer.has_data());
546 new_state.set(FileState::WRITABLE, common.sent_len < common.send_limit);
547 }
548
549 common.update_state(
550 FileState::all(),
551 new_state,
552 signals,
553 cb_queue,
554 );
555 }
556
557 fn close(
558 self,
559 common: &mut NetlinkSocketCommon,
560 cb_queue: &mut CallbackQueue,
561 ) -> (ProtocolState, Result<(), SyscallError>) {
562 common
564 .buffer
565 .borrow_mut()
566 .remove_reader(self.reader_handle, cb_queue);
567
568 let new_state = ClosedState {};
569 new_state.refresh_file_state(common, FileSignals::empty(), cb_queue);
570 (new_state.into(), Ok(()))
571 }
572
573 fn bind(
574 &mut self,
575 _common: &mut NetlinkSocketCommon,
576 _socket: &Arc<AtomicRefCell<NetlinkSocket>>,
577 addr: Option<&SockaddrStorage>,
578 _rng: impl rand::Rng,
579 ) -> Result<(), SyscallError> {
580 if self.bound_addr.is_some() {
582 return Err(Errno::EINVAL.into());
583 }
584 if addr.is_none() {
586 return Err(Errno::EFAULT.into());
587 }
588
589 let Some(addr) = addr.and_then(|x| x.as_netlink()) else {
591 log::warn!(
592 "Attempted to bind netlink socket to non-netlink address {:?}",
593 addr
594 );
595 return Err(Errno::EINVAL.into());
596 };
597
598 self.bound_addr = Some(*addr);
602
603 if (addr.groups() & !(RTMGRP_IPV4_IFADDR | RTMGRP_IPV6_IFADDR)) != 0 {
606 log::warn!(
607 "Attempted to bind netlink socket to an address with unsupported groups {}",
608 addr.groups()
609 );
610 return Err(Errno::EINVAL.into());
611 }
612
613 Ok(())
614 }
615
616 fn sendmsg(
617 &mut self,
618 common: &mut NetlinkSocketCommon,
619 socket: &Arc<AtomicRefCell<NetlinkSocket>>,
620 args: SendmsgArgs,
621 mem: &mut MemoryManager,
622 cb_queue: &mut CallbackQueue,
623 ) -> Result<libc::ssize_t, SyscallError> {
624 if !args.control_ptr.ptr().is_null() {
625 log::debug!("Netlink sockets don't yet support control data for sendmsg()");
626 return Err(Errno::EINVAL.into());
627 }
628
629 if let Some(addr) = args.addr {
631 let Some(addr) = addr.as_netlink() else {
633 log::warn!("Attempted to send to non-netlink address {:?}", args.addr);
634 return Err(Errno::EINVAL.into());
635 };
636 if addr.pid() != 0 {
638 log::warn!("Attempted to send to non-kernel netlink address {:?}", addr);
639 return Err(Errno::EINVAL.into());
640 }
641 if addr.groups() != 0 {
643 log::warn!("Attempted to send to netlink groups {:?}", addr);
644 return Err(Errno::EINVAL.into());
645 }
646 }
647
648 let rv = common.sendmsg(socket, args.iovs, args.flags, mem, cb_queue)?;
649
650 self.refresh_file_state(common, FileSignals::empty(), cb_queue);
651
652 Ok(rv.try_into().unwrap())
653 }
654
655 fn recvmsg(
656 &mut self,
657 common: &mut NetlinkSocketCommon,
658 socket: &Arc<AtomicRefCell<NetlinkSocket>>,
659 args: RecvmsgArgs,
660 mem: &mut MemoryManager,
661 cb_queue: &mut CallbackQueue,
662 ) -> Result<RecvmsgReturn, SyscallError> {
663 if !args.control_ptr.ptr().is_null() {
664 log::debug!("Netlink sockets don't yet support control data for recvmsg()");
665 return Err(Errno::EINVAL.into());
666 }
667 let Some(flags) = MsgFlags::from_bits(args.flags) else {
668 warn_once_then_debug!("Unrecognized recv flags: {:#b}", args.flags);
669 return Err(Errno::EINVAL.into());
670 };
671
672 let mut packet_buffer = Vec::new();
673 let (_rv, _num_removed_from_buf) =
674 common.recvmsg(socket, &mut packet_buffer, flags, mem, cb_queue)?;
675 self.refresh_file_state(common, FileSignals::empty(), cb_queue);
676
677 let mut writer = IoVecWriter::new(args.iovs, mem);
678
679 let src_addr = SockaddrStorage::from_netlink(&NetlinkAddr::new(0, 0));
681
682 if packet_buffer.len() < std::mem::size_of::<nlmsghdr>() {
683 log::warn!("The processed packet is too short");
684 return Err(Errno::EINVAL.into());
685 }
686
687 let buffer = {
688 let nlmsg_type = &packet_buffer[memoffset::span_of!(nlmsghdr, nlmsg_type)];
689 let nlmsg_type = u16::from_ne_bytes(nlmsg_type.try_into().unwrap());
690
691 match nlmsg_type {
692 RTM_GETLINK => {
693 let nlmsghdr_len = std::mem::size_of::<nlmsghdr>();
694 let ifinfomsg_len = std::mem::size_of::<ifinfomsg>();
695 let header_len = nlmsghdr_len + ifinfomsg_len;
696
697 if (nlmsghdr_len..header_len).contains(&packet_buffer.len()) {
703 log::debug!(
704 "Padding the RTM_GETLINK with zeroes to meet the minimum length"
705 );
706 packet_buffer.resize(header_len, 0);
707 packet_buffer[memoffset::span_of!(nlmsghdr, nlmsg_len)]
708 .copy_from_slice(&(header_len as u32).to_ne_bytes()[..]);
709 }
710 self.handle_ifinfomsg(common, &packet_buffer[..])
711 }
712 RTM_GETADDR => {
713 let nlmsghdr_len = std::mem::size_of::<nlmsghdr>();
714 let ifaddrmsg_len = std::mem::size_of::<ifaddrmsg>();
715 let header_len = nlmsghdr_len + ifaddrmsg_len;
716
717 if (nlmsghdr_len..header_len).contains(&packet_buffer.len()) {
723 log::debug!(
724 "Padding the RTM_GETADDR with zeroes to meet the minimum length"
725 );
726 packet_buffer.resize(header_len, 0);
727 packet_buffer[memoffset::span_of!(nlmsghdr, nlmsg_len)]
728 .copy_from_slice(&(header_len as u32).to_ne_bytes()[..]);
729 }
730 self.handle_ifaddrmsg(common, &packet_buffer[..])
731 }
732 _ => {
733 warn_once_then_debug!(
734 "Found unsupported nlmsg_type: {nlmsg_type} (only RTM_GETLINK
735 and RTM_GETADDR are supported)"
736 );
737 self.handle_error(&packet_buffer[..])
738 }
739 }
740 };
741
742 let mut total_copied = 0;
744 let mut buf = buffer.as_slice();
745 while !buf.is_empty() {
746 match writer.write(buf) {
747 Ok(0) => break,
748 Ok(n) => {
749 buf = &buf[n..];
750 total_copied += n;
751 }
752 Err(ref e) if e.kind() == ErrorKind::Interrupted => continue,
753 Err(e) => return Err(e.into()),
754 }
755 }
756
757 let return_val = if flags.contains(MsgFlags::MSG_TRUNC) {
758 buffer.len()
759 } else {
760 total_copied
761 };
762
763 Ok(RecvmsgReturn {
764 return_val: return_val.try_into().unwrap(),
765 addr: Some(src_addr),
766 msg_flags: 0,
767 control_len: 0,
768 })
769 }
770
771 fn handle_error(&self, bytes: &[u8]) -> Vec<u8> {
772 let nlmsg_seq = match bytes.get(memoffset::span_of!(nlmsghdr, nlmsg_seq)) {
774 Some(x) => u32::from_ne_bytes(x.try_into().unwrap()),
775 None => 0,
776 };
777
778 let msg = {
780 let len = None;
781 let nl_type = Nlmsg::Error;
782 let flags = NlmFFlags::empty();
783 let pid = None;
784 let payload = NlPayload::<Nlmsg, ()>::Empty;
785 Nlmsghdr::new(len, nl_type, flags, Some(nlmsg_seq), pid, payload)
786 };
787
788 let mut buffer = Cursor::new(Vec::new());
789 msg.to_bytes(&mut buffer).unwrap();
790 buffer.into_inner()
791 }
792
793 fn handle_ifaddrmsg(&self, common: &mut NetlinkSocketCommon, bytes: &[u8]) -> Vec<u8> {
794 let Ok(nlmsg) = Nlmsghdr::<Rtm, Ifaddrmsg>::from_bytes(&mut Cursor::new(bytes)) else {
795 log::warn!("Failed to deserialize the message");
796 return self.handle_error(bytes);
797 };
798
799 let Ok(ifaddrmsg) = nlmsg.get_payload() else {
800 log::warn!("Failed to find the payload");
801 return self.handle_error(bytes);
802 };
803
804 if ifaddrmsg.ifa_family != RtAddrFamily::Unspecified
806 && ifaddrmsg.ifa_family != RtAddrFamily::Inet
807 {
808 log::warn!("Unsupported ifa_family (only AF_UNSPEC and AF_INET are supported)");
809 return self.handle_error(bytes);
810 }
811
812 if ifaddrmsg.ifa_prefixlen != 0
814 || ifaddrmsg.ifa_flags != IfaFFlags::empty()
815 || ifaddrmsg.ifa_index != 0
816 || ifaddrmsg.ifa_scope != libc::c_uchar::from(RtScope::Universe)
817 {
818 log::warn!(
819 "Unsupported ifa_prefixlen, ifa_flags, ifa_scope, or ifa_index (they have to be 0)"
820 );
821 return self.handle_error(bytes);
822 }
823
824 let mut buffer = Cursor::new(Vec::new());
825 for interface in &common.interfaces {
827 let address = interface.address.octets();
828 let broadcast = Ipv4Addr::from(
829 0xffff_ffff_u32
830 .checked_shr(u32::from(interface.prefix_len))
831 .unwrap_or(0)
832 | u32::from(interface.address),
833 )
834 .octets();
835 let mut label = Vec::from(interface.label.as_bytes());
836 label.push(0); let attrs = [
840 Rtattr::new(None, Ifa::Address, Buffer::from(&address[..])).unwrap(),
844 Rtattr::new(None, Ifa::Local, Buffer::from(&address[..])).unwrap(),
845 Rtattr::new(None, Ifa::Broadcast, Buffer::from(&broadcast[..])).unwrap(),
846 Rtattr::new(None, Ifa::Label, Buffer::from(label)).unwrap(),
847 ];
848 let ifaddrmsg = Ifaddrmsg {
849 ifa_family: RtAddrFamily::Inet,
850 ifa_prefixlen: interface.prefix_len,
851 ifa_flags: IfaFFlags::new(&[IfaF::Permanent]),
853 ifa_scope: libc::c_uchar::from(interface.scope),
854 ifa_index: interface.index,
855 rtattrs: RtBuffer::from_iter(attrs),
856 };
857 let nlmsg = {
858 let len = None;
859 let nl_type = Rtm::Newaddr;
860 let flags = NlmFFlags::new(&[NlmF::Multi]);
862 let seq = Some(nlmsg.nl_seq);
864 let pid = None;
865 let payload = NlPayload::Payload(ifaddrmsg);
866 Nlmsghdr::new(len, nl_type, flags, seq, pid, payload)
867 };
868 nlmsg.to_bytes(&mut buffer).unwrap();
869 }
870 let done_msg = {
872 let len = None;
873 let nl_type = Nlmsg::Done;
874 let flags = NlmFFlags::new(&[NlmF::Multi]);
875 let seq = Some(nlmsg.nl_seq);
877 let pid = None;
878 let payload: NlPayload<Nlmsg, u32> = NlPayload::Payload(0);
882 Nlmsghdr::new(len, nl_type, flags, seq, pid, payload)
883 };
884 done_msg.to_bytes(&mut buffer).unwrap();
885
886 buffer.into_inner()
887 }
888
889 fn handle_ifinfomsg(&self, common: &mut NetlinkSocketCommon, bytes: &[u8]) -> Vec<u8> {
890 let Ok(nlmsg) = Nlmsghdr::<Rtm, Ifinfomsg>::from_bytes(&mut Cursor::new(bytes)) else {
891 log::warn!("Failed to deserialize the message");
892 return self.handle_error(bytes);
893 };
894
895 let Ok(ifinfomsg) = nlmsg.get_payload() else {
896 log::warn!("Failed to find the payload");
897 return self.handle_error(bytes);
898 };
899
900 if ifinfomsg.ifi_family != RtAddrFamily::Unspecified
902 && ifinfomsg.ifi_family != RtAddrFamily::Inet
903 {
904 warn_once_then_debug!(
905 "Unsupported ifi_family (only AF_UNSPEC and AF_INET are supported)"
906 );
907 return self.handle_error(bytes);
908 }
909
910 if ifinfomsg.ifi_type != 0.into()
912 || ifinfomsg.ifi_index != 0
913 || ifinfomsg.ifi_flags != IffFlags::empty()
914 {
915 warn_once_then_debug!(
916 "Unsupported ifi_type, ifi_index, or ifi_flags (they have to be 0)"
917 );
918 return self.handle_error(bytes);
919 }
920
921 let mut buffer = Cursor::new(Vec::new());
925 for interface in &common.interfaces {
927 let mut label = Vec::from(interface.label.as_bytes());
928 label.push(0); let attrs = [
932 Rtattr::new(None, Ifla::Ifname, Buffer::from(label)).unwrap(),
933 Rtattr::new(None, Ifla::Txqlen, Buffer::from(&1000u32.to_le_bytes()[..])).unwrap(),
937 Rtattr::new(
938 None,
939 Ifla::Mtu,
940 Buffer::from(&interface.mtu.to_le_bytes()[..]),
941 )
942 .unwrap(),
943 ];
945 let flags = if interface.if_type == Arphrd::Loopback {
946 IffFlags::new(&[Iff::Up, Iff::Loopback, Iff::Running])
947 } else {
948 IffFlags::new(&[Iff::Up, Iff::Broadcast, Iff::Running, Iff::Multicast])
950 };
951 let ifinfomsg = Ifinfomsg::new(
952 RtAddrFamily::Inet,
953 interface.if_type,
954 interface.index,
955 flags,
956 IffFlags::from_bitmask(0xffffffff), RtBuffer::from_iter(attrs),
958 );
959 let nlmsg = {
960 let len = None;
961 let nl_type = Rtm::Newlink;
962 let flags = NlmFFlags::new(&[NlmF::Multi]);
964 let seq = Some(nlmsg.nl_seq);
966 let pid = None;
967 let payload = NlPayload::Payload(ifinfomsg);
968 Nlmsghdr::new(len, nl_type, flags, seq, pid, payload)
969 };
970 nlmsg.to_bytes(&mut buffer).unwrap();
971 }
972 let done_msg = {
974 let len = None;
975 let nl_type = Nlmsg::Done;
976 let flags = NlmFFlags::new(&[NlmF::Multi]);
977 let seq = Some(nlmsg.nl_seq);
979 let pid = None;
980 let payload: NlPayload<Nlmsg, u32> = NlPayload::Payload(0);
984 Nlmsghdr::new(len, nl_type, flags, seq, pid, payload)
985 };
986 done_msg.to_bytes(&mut buffer).unwrap();
987
988 buffer.into_inner()
989 }
990}
991
992impl ClosedState {
993 fn bound_address(&self) -> Result<Option<NetlinkAddr>, Errno> {
994 Ok(None)
995 }
996
997 fn refresh_file_state(
998 &self,
999 common: &mut NetlinkSocketCommon,
1000 signals: FileSignals,
1001 cb_queue: &mut CallbackQueue,
1002 ) {
1003 common.update_state(
1004 FileState::all(),
1005 FileState::CLOSED,
1006 signals,
1007 cb_queue,
1008 );
1009 }
1010
1011 fn close(
1012 self,
1013 _common: &mut NetlinkSocketCommon,
1014 _cb_queue: &mut CallbackQueue,
1015 ) -> (ProtocolState, Result<(), SyscallError>) {
1016 panic!("Trying to close an already closed socket");
1018 }
1019
1020 fn bind(
1021 &mut self,
1022 _common: &mut NetlinkSocketCommon,
1023 _socket: &Arc<AtomicRefCell<NetlinkSocket>>,
1024 _addr: Option<&SockaddrStorage>,
1025 _rng: impl rand::Rng,
1026 ) -> Result<(), SyscallError> {
1027 log::warn!("bind() while in state {}", std::any::type_name::<Self>());
1029 Err(Errno::EOPNOTSUPP.into())
1030 }
1031
1032 fn sendmsg(
1033 &mut self,
1034 _common: &mut NetlinkSocketCommon,
1035 _socket: &Arc<AtomicRefCell<NetlinkSocket>>,
1036 _args: SendmsgArgs,
1037 _mem: &mut MemoryManager,
1038 _cb_queue: &mut CallbackQueue,
1039 ) -> Result<libc::ssize_t, SyscallError> {
1040 log::warn!("sendmsg() while in state {}", std::any::type_name::<Self>());
1042 Err(Errno::EOPNOTSUPP.into())
1043 }
1044
1045 fn recvmsg(
1046 &mut self,
1047 _common: &mut NetlinkSocketCommon,
1048 _socket: &Arc<AtomicRefCell<NetlinkSocket>>,
1049 _args: RecvmsgArgs,
1050 _mem: &mut MemoryManager,
1051 _cb_queue: &mut CallbackQueue,
1052 ) -> Result<RecvmsgReturn, SyscallError> {
1053 log::warn!("recvmsg() while in state {}", std::any::type_name::<Self>());
1055 Err(Errno::EOPNOTSUPP.into())
1056 }
1057}
1058
/// Description of one network interface reported in netlink responses.
struct Interface {
    /// IPv4 address of the interface.
    address: Ipv4Addr,
    /// Interface name (for example `lo` or `eth0`).
    label: String,
    /// Length of the network prefix in bits.
    prefix_len: u8,
    /// Hardware type of the interface (for example loopback or ethernet).
    if_type: Arphrd,
    /// MTU value reported for the interface.
    mtu: u32,
    /// Address scope reported for the interface.
    scope: RtScope,
    /// Interface index reported in responses.
    index: libc::c_int,
}
1069
/// Data and functionality shared by all protocol states of a netlink socket.
struct NetlinkSocketCommon {
    /// Buffer that sent messages are written to and received messages are read from.
    buffer: Arc<AtomicRefCell<SharedBuf>>,
    /// Upper bound for `sent_len`; the socket is writable while `sent_len` is below it.
    send_limit: u64,
    /// Number of bytes that have been counted as sent on this socket.
    sent_len: u64,
    /// Source of state-change notifications for listeners of this socket.
    event_source: StateEventSource,
    /// Current file state of the socket.
    state: FileState,
    /// File status flags of the socket (for example `NONBLOCK`).
    status: FileStatus,
    /// Flag stored and returned by the `set_has_open_file` / `has_open_file` accessors.
    has_open_file: bool,
    /// Interfaces that are reported in `RTM_GETLINK` / `RTM_GETADDR` responses.
    interfaces: Vec<Interface>,
}
1086
1087impl NetlinkSocketCommon {
1088 pub fn supports_sa_restart(&self) -> bool {
1089 true
1090 }
1091
1092 pub fn sendmsg(
1093 &mut self,
1094 socket: &Arc<AtomicRefCell<NetlinkSocket>>,
1095 iovs: &[IoVec],
1096 flags: libc::c_int,
1097 mem: &mut MemoryManager,
1098 cb_queue: &mut CallbackQueue,
1099 ) -> Result<usize, SyscallError> {
1100 let supported_flags = MsgFlags::MSG_DONTWAIT | MsgFlags::MSG_NOSIGNAL | MsgFlags::MSG_TRUNC;
1103
1104 let Some(mut flags) = MsgFlags::from_bits(flags) else {
1107 warn_once_then_debug!("Unrecognized send flags: {:#b}", flags);
1108 return Err(Errno::EINVAL.into());
1109 };
1110 if flags.intersects(!supported_flags) {
1111 warn_once_then_debug!("Unsupported send flags: {:?}", flags);
1112 return Err(Errno::EINVAL.into());
1113 }
1114
1115 if self.status.contains(FileStatus::NONBLOCK) {
1116 flags.insert(MsgFlags::MSG_DONTWAIT);
1117 }
1118
1119 let result = (|| {
1121 let len = iovs.iter().map(|x| x.len).sum::<libc::size_t>();
1122
1123 let space_available = self
1126 .send_limit
1127 .saturating_sub(self.sent_len)
1128 .try_into()
1129 .unwrap();
1130
1131 if space_available == 0 {
1132 return Err(Errno::EAGAIN);
1133 }
1134
1135 if len > space_available {
1136 if len <= self.send_limit.try_into().unwrap() {
1137 return Err(Errno::EAGAIN);
1139 } else {
1140 return Err(Errno::EMSGSIZE);
1142 }
1143 }
1144
1145 let reader = IoVecReader::new(iovs, mem);
1146 let reader = reader.take(len.try_into().unwrap());
1147
1148 self.buffer
1151 .borrow_mut()
1152 .write_packet(reader, len, cb_queue)
1153 .map_err(|e| Errno::try_from(e).unwrap())?;
1154
1155 self.sent_len += u64::try_from(len).unwrap();
1157 Ok(len)
1158 })();
1159
1160 if result.as_ref().err() == Some(&Errno::EWOULDBLOCK)
1162 && !flags.contains(MsgFlags::MSG_DONTWAIT)
1163 {
1164 return Err(SyscallError::new_blocked_on_file(
1165 File::Socket(Socket::Netlink(socket.clone())),
1166 FileState::WRITABLE,
1167 self.supports_sa_restart(),
1168 ));
1169 }
1170
1171 Ok(result?)
1172 }
1173
1174 pub fn recvmsg<W: Write>(
1175 &mut self,
1176 socket: &Arc<AtomicRefCell<NetlinkSocket>>,
1177 dst: W,
1178 mut flags: MsgFlags,
1179 _mem: &mut MemoryManager,
1180 cb_queue: &mut CallbackQueue,
1181 ) -> Result<(usize, usize), SyscallError> {
1182 let supported_flags = MsgFlags::MSG_DONTWAIT | MsgFlags::MSG_PEEK | MsgFlags::MSG_TRUNC;
1183
1184 if flags.intersects(!supported_flags) {
1187 warn_once_then_debug!("Unsupported recv flags: {:?}", flags);
1188 return Err(Errno::EINVAL.into());
1189 }
1190
1191 if self.status.contains(FileStatus::NONBLOCK) {
1192 flags.insert(MsgFlags::MSG_DONTWAIT);
1193 }
1194
1195 let result = (|| {
1197 let mut buffer = self.buffer.borrow_mut();
1198
1199 if !buffer.has_data() {
1201 return Err(Errno::EWOULDBLOCK);
1202 }
1203
1204 let (num_copied, num_removed_from_buf) = if flags.contains(MsgFlags::MSG_PEEK) {
1205 buffer.peek(dst).map_err(|e| Errno::try_from(e).unwrap())?
1206 } else {
1207 buffer
1208 .read(dst, cb_queue)
1209 .map_err(|e| Errno::try_from(e).unwrap())?
1210 };
1211
1212 if flags.contains(MsgFlags::MSG_TRUNC) {
1213 Ok((num_removed_from_buf, num_removed_from_buf))
1215 } else {
1216 Ok((num_copied, num_removed_from_buf))
1217 }
1218 })();
1219
1220 if result.as_ref().err() == Some(&Errno::EWOULDBLOCK)
1222 && !flags.contains(MsgFlags::MSG_DONTWAIT)
1223 {
1224 return Err(SyscallError::new_blocked_on_file(
1225 File::Socket(Socket::Netlink(socket.clone())),
1226 FileState::READABLE,
1227 self.supports_sa_restart(),
1228 ));
1229 }
1230
1231 Ok(result?)
1232 }
1233
1234 fn update_state(
1235 &mut self,
1236 mask: FileState,
1237 state: FileState,
1238 signals: FileSignals,
1239 cb_queue: &mut CallbackQueue,
1240 ) {
1241 let old_state = self.state;
1242
1243 self.state.remove(mask);
1245 self.state.insert(state & mask);
1246
1247 self.handle_state_change(old_state, signals, cb_queue);
1248 }
1249
1250 fn handle_state_change(
1251 &mut self,
1252 old_state: FileState,
1253 signals: FileSignals,
1254 cb_queue: &mut CallbackQueue,
1255 ) {
1256 let states_changed = self.state ^ old_state;
1257
1258 if states_changed.is_empty() && signals.is_empty() {
1260 return;
1261 }
1262
1263 self.event_source
1264 .notify_listeners(self.state, states_changed, signals, cb_queue);
1265 }
1266}
1267
/// The socket types that a netlink socket can be created with.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub enum NetlinkSocketType {
    /// Corresponds to `SOCK_DGRAM`.
    Dgram,
    /// Corresponds to `SOCK_RAW`.
    Raw,
}
1273
1274impl TryFrom<libc::c_int> for NetlinkSocketType {
1275 type Error = NetlinkSocketTypeConversionError;
1276 fn try_from(val: libc::c_int) -> Result<Self, Self::Error> {
1277 match val {
1278 libc::SOCK_DGRAM => Ok(Self::Dgram),
1279 libc::SOCK_RAW => Ok(Self::Raw),
1280 x => Err(NetlinkSocketTypeConversionError(x)),
1281 }
1282 }
1283}
1284
/// Error returned when a raw socket type is not a valid netlink socket type; wraps the
/// rejected value.
#[derive(Copy, Clone, Debug)]
pub struct NetlinkSocketTypeConversionError(libc::c_int);

// all `Error` methods use their default implementations
impl std::error::Error for NetlinkSocketTypeConversionError {}
1289
1290impl std::fmt::Display for NetlinkSocketTypeConversionError {
1291 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
1292 write!(
1293 f,
1294 "Invalid socket type {}; netlink sockets only support SOCK_DGRAM and SOCK_RAW",
1295 self.0
1296 )
1297 }
1298}
1299
/// The netlink families that a netlink socket can be created with.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub enum NetlinkFamily {
    /// Corresponds to `NETLINK_ROUTE`.
    Route,
}
1304
1305impl TryFrom<libc::c_int> for NetlinkFamily {
1306 type Error = NetlinkFamilyConversionError;
1307 fn try_from(val: libc::c_int) -> Result<Self, Self::Error> {
1308 match val {
1309 libc::NETLINK_ROUTE => Ok(Self::Route),
1310 x => Err(NetlinkFamilyConversionError(x)),
1311 }
1312 }
1313}
1314
/// Error returned when a raw value is not a supported netlink family; wraps the rejected
/// value.
#[derive(Copy, Clone, Debug)]
pub struct NetlinkFamilyConversionError(libc::c_int);

// all `Error` methods use their default implementations
impl std::error::Error for NetlinkFamilyConversionError {}
1319
1320impl std::fmt::Display for NetlinkFamilyConversionError {
1321 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
1322 write!(
1323 f,
1324 "Invalid netlink family {}; netlink families only support NETLINK_ROUTE",
1325 self.0
1326 )
1327 }
1328}