// shadow_rs/host/descriptor/socket/inet/udp.rs

1use std::collections::LinkedList;
2use std::io::{Read, Write};
3use std::net::{Ipv4Addr, SocketAddrV4};
4use std::sync::Arc;
5
6use atomic_refcell::AtomicRefCell;
7use bytes::{Bytes, BytesMut};
8use linux_api::errno::Errno;
9use linux_api::ioctls::IoctlRequest;
10use linux_api::socket::Shutdown;
11use nix::sys::socket::{MsgFlags, SockaddrIn};
12use shadow_shim_helper_rs::emulated_time::EmulatedTime;
13use shadow_shim_helper_rs::syscall_types::ForeignPtr;
14
15use crate::core::worker::Worker;
16use crate::cshadow as c;
17use crate::host::descriptor::listener::{StateEventSource, StateListenHandle, StateListenerFilter};
18use crate::host::descriptor::socket::inet::{self, InetSocket};
19use crate::host::descriptor::socket::{RecvmsgArgs, RecvmsgReturn, SendmsgArgs, ShutdownFlags};
20use crate::host::descriptor::{
21    File, FileMode, FileSignals, FileState, FileStatus, OpenFile, Socket, SyscallResult,
22};
23use crate::host::memory_manager::MemoryManager;
24use crate::host::network::interface::FifoPacketPriority;
25use crate::host::network::namespace::{AssociationHandle, NetworkNamespace};
26use crate::host::syscall::io::{IoVec, IoVecReader, IoVecWriter, write_partial};
27use crate::host::syscall::types::SyscallError;
28use crate::network::packet::{PacketRc, PacketStatus};
29use crate::utility::callback_queue::CallbackQueue;
30use crate::utility::sockaddr::SockaddrStorage;
31use crate::utility::{HostTreePointer, ObjectCounter};
32
/// Largest datagram payload we are allowed to send out over the network: the maximum IPv4
/// packet size (2^16 - 1 = 65,535 bytes) minus the 20-byte IP header and the 8-byte UDP header.
const CONFIG_DATAGRAM_MAX_SIZE: usize = u16::MAX as usize - 20 - 8;
36
/// A simulated UDP (datagram) socket.
///
/// Outgoing datagrams are queued in `send_buffer` until they are pulled with `pull_out_packet`.
/// Incoming datagrams are queued in `recv_buffer` by `push_in_packet` until they are read with
/// `recvmsg`.
pub struct UdpSocket {
    /// Source of state-change events; presumably notifies registered state listeners.
    event_source: StateEventSource,
    /// File status flags (e.g. `NONBLOCK`, which `sendmsg`/`recvmsg` turn into `MSG_DONTWAIT`).
    status: FileStatus,
    /// Current file state (`ACTIVE` when created, `CLOSED` after `close`).
    state: FileState,
    /// Which directions (read and/or write) have been shut down via `shutdown()`.
    shutdown_status: ShutdownFlags,
    /// Datagrams waiting to be sent out on the network.
    send_buffer: MessageBuffer<MessageSendHeader>,
    /// Datagrams received from the network and waiting to be read.
    recv_buffer: MessageBuffer<MessageRecvHeader>,
    /// Peer set by `connect()`. It is the default destination for `sendmsg`, and the only
    /// source address whose packets `push_in_packet` accepts.
    peer_addr: Option<SocketAddrV4>,
    /// Local address, set by `bind()` or implicitly by `connect()`/`sendmsg()`. May be
    /// `0.0.0.0`, in which case `sendmsg` picks the source IP per destination.
    bound_addr: Option<SocketAddrV4>,
    /// Handle for this socket's association with the network namespace. Dropping it
    /// disassociates the socket (see `close`).
    association: Option<AssociationHandle>,
    /// The receive time of the last packet returned to the managed process during a call to
    /// `recvmsg()`. Used for `SIOCGSTAMP`.
    recv_time_of_last_read_packet: Option<EmulatedTime>,
    // should only be used by `OpenFile` to make sure there is only ever one `OpenFile` instance for
    // this file
    has_open_file: bool,
    /// Instance counter labelled "UdpSocket" (presumably for object-count/leak diagnostics).
    _counter: ObjectCounter,
}
55
56impl UdpSocket {
57    pub fn new(
58        status: FileStatus,
59        send_buf_size: usize,
60        recv_buf_size: usize,
61    ) -> Arc<AtomicRefCell<Self>> {
62        let mut socket = Self {
63            event_source: StateEventSource::new(),
64            status,
65            state: FileState::ACTIVE,
66            shutdown_status: ShutdownFlags::empty(),
67            send_buffer: MessageBuffer::new(send_buf_size),
68            recv_buffer: MessageBuffer::new(recv_buf_size),
69            peer_addr: None,
70            bound_addr: None,
71            association: None,
72            recv_time_of_last_read_packet: None,
73            has_open_file: false,
74            _counter: ObjectCounter::new("UdpSocket"),
75        };
76
77        CallbackQueue::queue_and_run_with_legacy(|cb_queue| {
78            socket.refresh_readable_writable(FileSignals::empty(), cb_queue)
79        });
80
81        Arc::new(AtomicRefCell::new(socket))
82    }
83
84    pub fn status(&self) -> FileStatus {
85        self.status
86    }
87
88    pub fn set_status(&mut self, status: FileStatus) {
89        self.status = status;
90    }
91
92    pub fn mode(&self) -> FileMode {
93        FileMode::READ | FileMode::WRITE
94    }
95
96    pub fn has_open_file(&self) -> bool {
97        self.has_open_file
98    }
99
100    pub fn supports_sa_restart(&self) -> bool {
101        true
102    }
103
104    pub fn set_has_open_file(&mut self, val: bool) {
105        self.has_open_file = val;
106    }
107
108    pub fn push_in_packet(
109        &mut self,
110        packet: PacketRc,
111        cb_queue: &mut CallbackQueue,
112        recv_time: EmulatedTime,
113    ) {
114        packet.add_status(PacketStatus::RcvSocketProcessed);
115
116        if let Some(peer_addr) = self.peer_addr {
117            if peer_addr != packet.src_ipv4_address() {
118                // connect(2): "If the socket sockfd is of type SOCK_DGRAM, then addr is the address
119                // to which datagrams are sent by default, and the only address from which datagrams
120                // are received."
121
122                // we have a peer, but received a packet from a different source address than that
123                // peer
124                packet.add_status(PacketStatus::RcvSocketDropped);
125
126                // TODO: There's a race condition where we check the packet's address only when
127                // receiving the packet from the network interface, but the user could call
128                // `connect()` to set a peer after we've already received and buffered this packet.
129                // My guess is that this race condition exists in Linux as well, but ideally we
130                // should add a test, and do another check when `recvmsg()` is called if we really
131                // need to.
132
133                return;
134            }
135        };
136
137        // TODO: also check the dst address to make sure we are the intended socket?
138
139        // don't bother copying the bytes if we know the push will fail
140        if !self.recv_buffer.has_space() {
141            packet.add_status(PacketStatus::RcvSocketDropped);
142            return;
143        }
144
145        // transfer the `Bytes` directly from the packet to the buffer without copying the bytes.
146        // we use concat in case the payload has mutiple chunks, but that should not happen in
147        // the normal case since we only send UDP messages with a single `Bytes` object.
148
149        let payload = tcp::Payload(packet.payload());
150        assert_eq!(payload.len() as usize, packet.payload_len());
151        let message = payload.concat();
152
153        let header = MessageRecvHeader {
154            src: packet.src_ipv4_address(),
155            dst: packet.dst_ipv4_address(),
156            recv_time,
157        };
158
159        // push the message to the receive buffer (shouldn't fail since we checked for available
160        // space above)
161        self.recv_buffer.push_message(message, header).unwrap();
162
163        log::trace!("Added a packet to the UDP socket's recv buffer");
164        packet.add_status(PacketStatus::RcvSocketBuffered);
165
166        self.refresh_readable_writable(FileSignals::READ_BUFFER_GREW, cb_queue);
167    }
168
169    pub fn pull_out_packet(&mut self, cb_queue: &mut CallbackQueue) -> Option<PacketRc> {
170        // pop the message from the send buffer
171        let Some((message, header)) = self.send_buffer.pop_message() else {
172            log::debug!(
173                "Attempted to remove a message from the UDP socket's send buffer, but none available"
174            );
175
176            return None;
177        };
178
179        log::trace!("Removed a message from the UDP socket's send buffer");
180
181        // We transfer the `Bytes` directly from the buffer to the packet without copying them.
182        let packet =
183            PacketRc::new_ipv4_udp(header.src, header.dst, message, header.packet_priority);
184        packet.add_status(PacketStatus::SndCreated);
185
186        self.refresh_readable_writable(FileSignals::empty(), cb_queue);
187
188        Some(packet)
189    }
190
191    pub fn peek_next_packet_priority(&self) -> Option<FifoPacketPriority> {
192        self.send_buffer.buffer.front().map(|x| x.1.packet_priority)
193    }
194
195    pub fn has_data_to_send(&self) -> bool {
196        !self.send_buffer.is_empty()
197    }
198
199    pub fn getsockname(&self) -> Result<Option<SockaddrIn>, Errno> {
200        let mut addr = self
201            .bound_addr
202            .unwrap_or(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0));
203
204        // if we are bound to INADDR_ANY, we should instead return the IP used to communicate with
205        // the connected peer (if we have one)
206        if *addr.ip() == Ipv4Addr::UNSPECIFIED {
207            if let Some(peer_addr) = self.peer_addr {
208                addr.set_ip(*peer_addr.ip());
209            }
210        }
211
212        Ok(Some(addr.into()))
213    }
214
215    pub fn getpeername(&self) -> Result<Option<SockaddrIn>, Errno> {
216        Ok(Some(self.peer_addr.ok_or(Errno::ENOTCONN)?.into()))
217    }
218
219    pub fn address_family(&self) -> linux_api::socket::AddressFamily {
220        linux_api::socket::AddressFamily::AF_INET
221    }
222
223    pub fn close(&mut self, cb_queue: &mut CallbackQueue) -> Result<(), SyscallError> {
224        // drop the existing association handle to disassociate the socket
225        self.association = None;
226
227        self.update_state(
228            /* mask= */ FileState::all(),
229            FileState::CLOSED,
230            FileSignals::empty(),
231            cb_queue,
232        );
233        Ok(())
234    }
235
236    pub fn bind(
237        socket: &Arc<AtomicRefCell<Self>>,
238        addr: Option<&SockaddrStorage>,
239        net_ns: &NetworkNamespace,
240        rng: impl rand::Rng,
241    ) -> Result<(), SyscallError> {
242        // if the address pointer was NULL
243        let Some(addr) = addr else {
244            return Err(Errno::EFAULT.into());
245        };
246
247        // if not an inet socket address
248        let Some(addr) = addr.as_inet() else {
249            return Err(Errno::EINVAL.into());
250        };
251
252        let addr: SocketAddrV4 = (*addr).into();
253
254        {
255            let socket = socket.borrow();
256
257            // if the socket is already bound
258            if socket.bound_addr.is_some() {
259                return Err(Errno::EINVAL.into());
260            }
261
262            // Since we're not bound, we must not have a peer. We may have a peer in the future if
263            // `connect()` is called on this socket.
264            assert!(socket.peer_addr.is_none());
265
266            // must not have been associated with the network interface
267            assert!(socket.association.is_none());
268        }
269
270        // this will allow us to receive packets from any peer
271        let unspecified_addr = SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0);
272
273        // associate the socket
274        let (addr, handle) = inet::associate_socket(
275            InetSocket::Udp(Arc::clone(socket)),
276            addr,
277            unspecified_addr,
278            /* check_generic_peer= */ true,
279            net_ns,
280            rng,
281        )?;
282
283        // update the socket's local address
284        {
285            let mut socket = socket.borrow_mut();
286            socket.bound_addr = Some(addr);
287            socket.association = Some(handle);
288        }
289
290        Ok(())
291    }
292
293    pub fn readv(
294        &mut self,
295        _iovs: &[IoVec],
296        _offset: Option<libc::off_t>,
297        _flags: libc::c_int,
298        _mem: &mut MemoryManager,
299        _cb_queue: &mut CallbackQueue,
300    ) -> Result<libc::ssize_t, SyscallError> {
301        // we could call UdpSocket::recvmsg() here, but for now we expect that there are no code
302        // paths that would call UdpSocket::readv() since the readv() syscall handler should have
303        // called UdpSocket::recvmsg() instead
304        panic!("Called UdpSocket::readv() on a UDP socket");
305    }
306
307    pub fn writev(
308        &mut self,
309        _iovs: &[IoVec],
310        _offset: Option<libc::off_t>,
311        _flags: libc::c_int,
312        _mem: &mut MemoryManager,
313        _cb_queue: &mut CallbackQueue,
314    ) -> Result<libc::ssize_t, SyscallError> {
315        // we could call UdpSocket::sendmsg() here, but for now we expect that there are no code
316        // paths that would call UdpSocket::writev() since the writev() syscall handler should have
317        // called UdpSocket::sendmsg() instead
318        panic!("Called UdpSocket::writev() on a UDP socket");
319    }
320
321    pub fn sendmsg(
322        socket: &Arc<AtomicRefCell<Self>>,
323        args: SendmsgArgs,
324        mem: &mut MemoryManager,
325        net_ns: &NetworkNamespace,
326        rng: impl rand::Rng,
327        cb_queue: &mut CallbackQueue,
328    ) -> Result<libc::ssize_t, SyscallError> {
329        let mut socket_ref = socket.borrow_mut();
330
331        // if the file's writing has been shut down, return EPIPE
332        if socket_ref.shutdown_status.contains(ShutdownFlags::WRITE) {
333            return Err(linux_api::errno::Errno::EPIPE.into());
334        }
335
336        let Some(mut flags) = MsgFlags::from_bits(args.flags) else {
337            log::debug!("Unrecognized send flags: {:#b}", args.flags);
338            return Err(Errno::EINVAL.into());
339        };
340
341        // TODO: If we have a peer AND a destination address is provided, should we use the peer or
342        // the destination address? Do we have a test for this?
343        let dst_addr = match args.addr {
344            Some(addr) => match addr.as_inet() {
345                // an inet socket address
346                Some(x) => (*x).into(),
347                // not an inet socket address
348                None => return Err(Errno::EAFNOSUPPORT.into()),
349            },
350            // no destination address provided
351            None => match socket_ref.peer_addr {
352                Some(x) => x,
353                None => return Err(Errno::EDESTADDRREQ.into()),
354            },
355        };
356
357        if socket_ref.status().contains(FileStatus::NONBLOCK) {
358            flags.insert(MsgFlags::MSG_DONTWAIT);
359        }
360
361        let len: libc::size_t = args.iovs.iter().map(|x| x.len).sum();
362
363        // TODO: should use IP fragmentation to make sure packets fit within the MTU
364        if len > CONFIG_DATAGRAM_MAX_SIZE {
365            return Err(linux_api::errno::Errno::EMSGSIZE.into());
366        }
367
368        // make sure that we're bound
369        if let Some(bound_addr) = socket_ref.bound_addr {
370            // we must have an association since we're bound
371            assert!(socket_ref.association.is_some());
372
373            // make sure the new peer address is connectable from the bound interface
374            if !bound_addr.ip().is_unspecified() {
375                // assume that a socket bound to 0.0.0.0 can connect anywhere, so only check
376                // localhost
377                match (
378                    bound_addr.ip() == &Ipv4Addr::LOCALHOST,
379                    dst_addr.ip() == &Ipv4Addr::LOCALHOST,
380                ) {
381                    // bound and peer on loopback interface
382                    (true, true) => {}
383                    // neither bound nor peer on loopback interface (shadow treats any
384                    // non-127.0.0.1 address as an "internet" address)
385                    (false, false) => {}
386                    _ => return Err(Errno::EINVAL.into()),
387                }
388            }
389        } else {
390            // we can't be unbound but have a peer
391            assert!(socket_ref.peer_addr.is_none());
392            assert!(socket_ref.association.is_none());
393
394            // implicit bind to 0.0.0.0
395            let local_addr = SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0);
396
397            // this will allow us to receive packets from any peer
398            let unspecified_addr = SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0);
399
400            let (local_addr, handle) = super::associate_socket(
401                InetSocket::Udp(Arc::clone(socket)),
402                local_addr,
403                unspecified_addr,
404                /* check_generic_peer= */ true,
405                net_ns,
406                rng,
407            )?;
408
409            socket_ref.bound_addr = Some(local_addr);
410            socket_ref.association = Some(handle);
411        }
412
413        // run in a closure so that an early return doesn't skip checking if we should block
414        let result = (|| {
415            // don't bother copying the bytes if we know the push will fail
416            if !socket_ref.send_buffer.has_space() {
417                return Err(Errno::EWOULDBLOCK);
418            }
419
420            // write the iovs to an empty message
421            let mut reader = IoVecReader::new(args.iovs, mem);
422            let mut message = BytesMut::zeroed(len);
423            reader
424                .read_exact(&mut message[..])
425                .map_err(|e| Errno::try_from(e).unwrap())?;
426
427            // get the priority that we'll assign to the eventual packet
428            let packet_priority =
429                Worker::with_active_host(|host| host.get_next_packet_priority()).unwrap();
430
431            let src_addr = socket_ref.bound_addr.unwrap();
432            let src_addr = if src_addr.ip().is_unspecified() {
433                // depending on the destination address, choose either localhost or the public IP
434                // address
435                if dst_addr.ip() == &std::net::Ipv4Addr::LOCALHOST {
436                    SocketAddrV4::new(Ipv4Addr::LOCALHOST, src_addr.port())
437                } else {
438                    SocketAddrV4::new(net_ns.default_ip, src_addr.port())
439                }
440            } else {
441                src_addr
442            };
443
444            let header = MessageSendHeader {
445                src: src_addr,
446                dst: dst_addr,
447                packet_priority,
448            };
449
450            // push the message to the send buffer (shouldn't fail since we checked for available
451            // space above)
452            socket_ref
453                .send_buffer
454                .push_message(message.freeze(), header)
455                .unwrap();
456
457            // notify the host that this socket has packets to send
458            let socket = Arc::clone(socket);
459            let interface_ip = *socket_ref.bound_addr.unwrap().ip();
460            cb_queue.add(move |_cb_queue| {
461                Worker::with_active_host(|host| {
462                    let socket = InetSocket::Udp(socket);
463                    host.notify_socket_has_packets(interface_ip, &socket);
464                })
465                .unwrap();
466            });
467
468            Ok(len)
469        })();
470
471        socket_ref.refresh_readable_writable(FileSignals::empty(), cb_queue);
472
473        // if the syscall would block and we don't have the MSG_DONTWAIT flag
474        if result == Err(Errno::EWOULDBLOCK) && !flags.contains(MsgFlags::MSG_DONTWAIT) {
475            return Err(SyscallError::new_blocked_on_file(
476                File::Socket(Socket::Inet(InetSocket::Udp(socket.clone()))),
477                FileState::WRITABLE,
478                socket_ref.supports_sa_restart(),
479            ));
480        }
481
482        Ok(result?.try_into().unwrap())
483    }
484
485    pub fn recvmsg(
486        socket: &Arc<AtomicRefCell<Self>>,
487        args: RecvmsgArgs,
488        mem: &mut MemoryManager,
489        cb_queue: &mut CallbackQueue,
490    ) -> Result<RecvmsgReturn, SyscallError> {
491        let socket_ref = &mut *socket.borrow_mut();
492
493        let Some(mut flags) = MsgFlags::from_bits(args.flags) else {
494            log::debug!("Unrecognized recv flags: {:#b}", args.flags);
495            return Err(Errno::EINVAL.into());
496        };
497
498        if socket_ref.status().contains(FileStatus::NONBLOCK) {
499            flags.insert(MsgFlags::MSG_DONTWAIT);
500        }
501
502        let len: libc::size_t = args.iovs.iter().map(|x| x.len).sum();
503
504        // run in a closure so that an early return doesn't skip checking if we should block
505        let result = (|| {
506            // a temporary location to store the message and header if we popped them
507            let message_storage;
508            let header_storage;
509
510            let (message, header) = if !flags.contains(MsgFlags::MSG_PEEK) {
511                // pop the message from the receive buffer
512                (message_storage, header_storage) = socket_ref
513                    .recv_buffer
514                    .pop_message()
515                    .ok_or(Errno::EWOULDBLOCK)?;
516                (&message_storage, &header_storage)
517            } else {
518                // peek the message from the receive buffer
519                let (message, header) = socket_ref
520                    .recv_buffer
521                    .peek_message()
522                    .ok_or(Errno::EWOULDBLOCK)?;
523                (message, header)
524            };
525
526            // truncate the payload if the payload is larger than the user-provided buffers
527            let truncated_message = &message[..std::cmp::min(len, message.len())];
528
529            // write the truncated message to the iovs
530            let mut writer = IoVecWriter::new(args.iovs, mem);
531            writer
532                .write_all(truncated_message)
533                .map_err(|e| Errno::try_from(e).unwrap())?;
534
535            let return_val = if flags.contains(MsgFlags::MSG_TRUNC) {
536                message.len()
537            } else {
538                // the number of bytes written
539                truncated_message.len()
540            };
541
542            let mut return_flags = MsgFlags::empty();
543            return_flags.set(MsgFlags::MSG_TRUNC, truncated_message.len() < message.len());
544
545            // update the cache of the last recv time
546            socket_ref.recv_time_of_last_read_packet = Some(header.recv_time);
547
548            Ok(RecvmsgReturn {
549                return_val: return_val.try_into().unwrap(),
550                addr: Some(header.src.into()),
551                msg_flags: return_flags.bits(),
552                control_len: 0,
553            })
554        })();
555
556        socket_ref.refresh_readable_writable(FileSignals::empty(), cb_queue);
557
558        // if the syscall would block and we don't have the MSG_DONTWAIT flag
559        if result.as_ref().err() == Some(&Errno::EWOULDBLOCK)
560            && !flags.contains(MsgFlags::MSG_DONTWAIT)
561        {
562            // if the syscall would block but the file's reading has been shut down, return EOF
563            if socket_ref.shutdown_status.contains(ShutdownFlags::READ) {
564                return Ok(RecvmsgReturn {
565                    return_val: 0,
566                    addr: None,
567                    msg_flags: 0,
568                    control_len: 0,
569                });
570            }
571
572            return Err(SyscallError::new_blocked_on_file(
573                File::Socket(Socket::Inet(InetSocket::Udp(socket.clone()))),
574                FileState::READABLE,
575                socket_ref.supports_sa_restart(),
576            ));
577        }
578
579        Ok(result?)
580    }
581
582    pub fn ioctl(
583        &mut self,
584        request: IoctlRequest,
585        arg_ptr: ForeignPtr<()>,
586        mem: &mut MemoryManager,
587    ) -> SyscallResult {
588        match request {
589            // equivalent to SIOCINQ
590            IoctlRequest::FIONREAD => {
591                let len = self
592                    .recv_buffer
593                    .peek_message()
594                    .map(|m| m.0.len())
595                    .unwrap_or(0)
596                    .try_into()
597                    .unwrap();
598
599                let arg_ptr = arg_ptr.cast::<libc::c_int>();
600                mem.write(arg_ptr, &len)?;
601
602                Ok(0.into())
603            }
604            // equivalent to SIOCOUTQ
605            IoctlRequest::TIOCOUTQ => {
606                let len = self.send_buffer.len_bytes().try_into().unwrap();
607
608                let arg_ptr = arg_ptr.cast::<libc::c_int>();
609                mem.write(arg_ptr, &len)?;
610
611                Ok(0.into())
612            }
613            IoctlRequest::SIOCGSTAMP => {
614                // socket(7): "Return a struct timeval with the receive timestamp of the last packet
615                // passed to the user. [...] This ioctl should only be used if the socket option
616                // SO_TIMESTAMP is not set on the socket. Otherwise, it returns the timestamp of the
617                // last packet that was received while SO_TIMESTAMP was not set, or it fails if no
618                // such packet has been received, (i.e., ioctl(2) returns -1 with errno set to
619                // ENOENT)."
620                let Some(last_recv_time) = self.recv_time_of_last_read_packet else {
621                    return Err(Errno::ENOENT.into());
622                };
623
624                let last_recv_time = (last_recv_time - EmulatedTime::UNIX_EPOCH)
625                    .try_into()
626                    .unwrap();
627
628                let arg_ptr = arg_ptr.cast::<libc::timeval>();
629                mem.write(arg_ptr, &last_recv_time)?;
630
631                Ok(0.into())
632            }
633            IoctlRequest::FIONBIO => {
634                panic!("This should have been handled by the ioctl syscall handler");
635            }
636            IoctlRequest::TCGETS
637            | IoctlRequest::TCSETS
638            | IoctlRequest::TCSETSW
639            | IoctlRequest::TCSETSF
640            | IoctlRequest::TCGETA
641            | IoctlRequest::TCSETA
642            | IoctlRequest::TCSETAW
643            | IoctlRequest::TCSETAF
644            | IoctlRequest::TIOCGWINSZ
645            | IoctlRequest::TIOCSWINSZ => {
646                // not a terminal
647                Err(Errno::ENOTTY.into())
648            }
649            request => {
650                warn_once_then_debug!(
651                    "We do not yet handle ioctl request {request:?} on tcp sockets"
652                );
653                Err(Errno::EINVAL.into())
654            }
655        }
656    }
657
658    pub fn stat(&self) -> Result<linux_api::stat::stat, SyscallError> {
659        warn_once_then_debug!("We do not yet handle stat calls on udp sockets");
660        Err(Errno::EINVAL.into())
661    }
662
663    pub fn listen(
664        _socket: &Arc<AtomicRefCell<Self>>,
665        _backlog: i32,
666        _net_ns: &NetworkNamespace,
667        _rng: impl rand::Rng,
668        _cb_queue: &mut CallbackQueue,
669    ) -> Result<(), Errno> {
670        Err(Errno::EOPNOTSUPP)
671    }
672
673    pub fn connect(
674        socket: &Arc<AtomicRefCell<Self>>,
675        peer_addr: &SockaddrStorage,
676        net_ns: &NetworkNamespace,
677        rng: impl rand::Rng,
678        _cb_queue: &mut CallbackQueue,
679    ) -> Result<(), SyscallError> {
680        // if not an inet socket address
681        // TODO: handle an AF_UNSPEC socket address
682        let Some(peer_addr) = peer_addr.as_inet() else {
683            return Err(Errno::EINVAL.into());
684        };
685
686        let mut peer_addr: std::net::SocketAddrV4 = (*peer_addr).into();
687
688        // https://stackoverflow.com/a/22425796
689        if peer_addr.ip().is_unspecified() {
690            peer_addr.set_ip(std::net::Ipv4Addr::LOCALHOST);
691        }
692
693        // NOTE: it would be nice to use `Ipv4Addr::is_loopback` in this code rather than comparing
694        // to `Ipv4Addr::LOCALHOST`, but the rest of Shadow probably can't handle other loopback
695        // addresses (ex: 127.0.0.2) and it's probably best not to change this behaviour
696
697        // make sure we will be able to route this later
698        // TODO: UDP sockets probably shouldn't return `ECONNREFUSED`
699        if peer_addr.ip() != &std::net::Ipv4Addr::LOCALHOST {
700            let is_routable =
701                Worker::is_routable(net_ns.default_ip.into(), (*peer_addr.ip()).into());
702
703            if !is_routable {
704                // can't route it - there is no node with this address
705                log::warn!(
706                    "Attempting to connect to address '{peer_addr}' for which no host exists"
707                );
708                return Err(Errno::ECONNREFUSED.into());
709            }
710        }
711
712        // make sure that we're bound
713        {
714            let mut socket_ref = socket.borrow_mut();
715
716            if let Some(bound_addr) = socket_ref.bound_addr {
717                // we must have an association since we're bound
718                assert!(socket_ref.association.is_some());
719
720                // make sure the new peer address is connectable from the bound interface
721                if !bound_addr.ip().is_unspecified() {
722                    // assume that a socket bound to 0.0.0.0 can connect anywhere, so only check
723                    // localhost
724                    match (
725                        bound_addr.ip() == &Ipv4Addr::LOCALHOST,
726                        peer_addr.ip() == &Ipv4Addr::LOCALHOST,
727                    ) {
728                        // bound and peer on loopback interface
729                        (true, true) => {}
730                        // neither bound nor peer on loopback interface (shadow treats any
731                        // non-127.0.0.1 address as an "internet" address)
732                        (false, false) => {}
733                        _ => return Err(Errno::EINVAL.into()),
734                    }
735                }
736            } else {
737                // we can't be unbound but have a peer
738                assert!(socket_ref.peer_addr.is_none());
739                assert!(socket_ref.association.is_none());
740
741                // implicit bind (use default interface unless the remote peer is on loopback)
742                let local_addr = if peer_addr.ip() == &std::net::Ipv4Addr::LOCALHOST {
743                    SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0)
744                } else {
745                    SocketAddrV4::new(net_ns.default_ip, 0)
746                };
747
748                // this will allow us to receive packets from any source address, but
749                // `push_in_packet` should drop any packets that aren't from the peer
750                let unspecified_addr = SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0);
751
752                let (local_addr, handle) = super::associate_socket(
753                    InetSocket::Udp(Arc::clone(socket)),
754                    local_addr,
755                    unspecified_addr,
756                    /* check_generic_peer= */ true,
757                    net_ns,
758                    rng,
759                )?;
760
761                socket_ref.bound_addr = Some(local_addr);
762                socket_ref.association = Some(handle);
763            }
764
765            socket_ref.peer_addr = Some(peer_addr);
766        }
767
768        Ok(())
769    }
770
771    pub fn accept(
772        &mut self,
773        _net_ns: &NetworkNamespace,
774        _rng: impl rand::Rng,
775        _cb_queue: &mut CallbackQueue,
776    ) -> Result<OpenFile, SyscallError> {
777        Err(Errno::EOPNOTSUPP.into())
778    }
779
780    pub fn shutdown(
781        &mut self,
782        how: Shutdown,
783        _cb_queue: &mut CallbackQueue,
784    ) -> Result<(), SyscallError> {
785        // TODO: what if we set a peer, then unset the peer, then call shutdown?
786        if self.peer_addr.is_none() {
787            return Err(Errno::ENOTCONN.into());
788        }
789
790        if how == Shutdown::SHUT_WR || how == Shutdown::SHUT_RDWR {
791            // writing has been shut down
792            self.shutdown_status.insert(ShutdownFlags::WRITE)
793        }
794
795        if how == Shutdown::SHUT_RD || how == Shutdown::SHUT_RDWR {
796            // reading has been shut down
797            self.shutdown_status.insert(ShutdownFlags::READ)
798        }
799
800        Ok(())
801    }
802
803    pub fn getsockopt(
804        &mut self,
805        level: libc::c_int,
806        optname: libc::c_int,
807        optval_ptr: ForeignPtr<()>,
808        optlen: libc::socklen_t,
809        mem: &mut MemoryManager,
810        _cb_queue: &mut CallbackQueue,
811    ) -> Result<libc::socklen_t, SyscallError> {
812        match (level, optname) {
813            (libc::SOL_SOCKET, libc::SO_SNDBUF) => {
814                let sndbuf_size = self.send_buffer.soft_limit_bytes().try_into().unwrap();
815
816                let optval_ptr = optval_ptr.cast::<libc::c_int>();
817                let bytes_written = write_partial(mem, &sndbuf_size, optval_ptr, optlen as usize)?;
818
819                Ok(bytes_written as libc::socklen_t)
820            }
821            (libc::SOL_SOCKET, libc::SO_RCVBUF) => {
822                let rcvbuf_size = self.recv_buffer.soft_limit_bytes().try_into().unwrap();
823
824                let optval_ptr = optval_ptr.cast::<libc::c_int>();
825                let bytes_written = write_partial(mem, &rcvbuf_size, optval_ptr, optlen as usize)?;
826
827                Ok(bytes_written as libc::socklen_t)
828            }
829            (libc::SOL_SOCKET, libc::SO_ERROR) => {
830                let error = 0;
831
832                let optval_ptr = optval_ptr.cast::<libc::c_int>();
833                let bytes_written = write_partial(mem, &error, optval_ptr, optlen as usize)?;
834
835                Ok(bytes_written as libc::socklen_t)
836            }
837            (libc::SOL_SOCKET, libc::SO_DOMAIN) => {
838                let domain = libc::AF_INET;
839
840                let optval_ptr = optval_ptr.cast::<libc::c_int>();
841                let bytes_written = write_partial(mem, &domain, optval_ptr, optlen as usize)?;
842
843                Ok(bytes_written as libc::socklen_t)
844            }
845            (libc::SOL_SOCKET, libc::SO_TYPE) => {
846                let sock_type = libc::SOCK_DGRAM;
847
848                let optval_ptr = optval_ptr.cast::<libc::c_int>();
849                let bytes_written = write_partial(mem, &sock_type, optval_ptr, optlen as usize)?;
850
851                Ok(bytes_written as libc::socklen_t)
852            }
853            (libc::SOL_SOCKET, libc::SO_PROTOCOL) => {
854                let protocol = libc::IPPROTO_UDP;
855
856                let optval_ptr = optval_ptr.cast::<libc::c_int>();
857                let bytes_written = write_partial(mem, &protocol, optval_ptr, optlen as usize)?;
858
859                Ok(bytes_written as libc::socklen_t)
860            }
861            (libc::SOL_SOCKET, libc::SO_ACCEPTCONN) => {
862                let optval_ptr = optval_ptr.cast::<libc::c_int>();
863                let bytes_written = write_partial(mem, &0, optval_ptr, optlen as usize)?;
864
865                Ok(bytes_written as libc::socklen_t)
866            }
867            (libc::SOL_SOCKET, libc::SO_BROADCAST) => {
868                let optval_ptr = optval_ptr.cast::<libc::c_int>();
869                // we don't support broadcast sockets, so just just return the default 0
870                let bytes_written = write_partial(mem, &0, optval_ptr, optlen as usize)?;
871
872                Ok(bytes_written as libc::socklen_t)
873            }
874            (libc::SOL_SOCKET, _) => {
875                log_once_per_value_at_level!(
876                    (level, optname),
877                    (i32, i32),
878                    log::Level::Warn,
879                    log::Level::Debug,
880                    "getsockopt called with unsupported level {level} and opt {optname}"
881                );
882                Err(Errno::ENOPROTOOPT.into())
883            }
884            _ => {
885                log_once_per_value_at_level!(
886                    (level, optname),
887                    (i32, i32),
888                    log::Level::Warn,
889                    log::Level::Debug,
890                    "getsockopt called with unsupported level {level} and opt {optname}"
891                );
892                Err(Errno::EOPNOTSUPP.into())
893            }
894        }
895    }
896
897    pub fn setsockopt(
898        &mut self,
899        level: libc::c_int,
900        optname: libc::c_int,
901        optval_ptr: ForeignPtr<()>,
902        optlen: libc::socklen_t,
903        mem: &MemoryManager,
904    ) -> Result<(), SyscallError> {
905        match (level, optname) {
906            (libc::SOL_SOCKET, libc::SO_SNDBUF) => {
907                type OptType = libc::c_int;
908
909                if usize::try_from(optlen).unwrap() < std::mem::size_of::<OptType>() {
910                    return Err(Errno::EINVAL.into());
911                }
912
913                let optval_ptr = optval_ptr.cast::<OptType>();
914                let val: u64 = mem.read(optval_ptr)?.try_into().or(Err(Errno::EINVAL))?;
915
916                // linux kernel doubles this value upon setting
917                let val = val * 2;
918
919                // Linux also has limits SOCK_MIN_SNDBUF (slightly greater than 4096) and the sysctl
920                // max limit. We choose a reasonable lower limit for Shadow. The minimum limit in
921                // man 7 socket is incorrect.
922                let val = std::cmp::max(val, 4096);
923
924                // This upper limit was added as an arbitrarily high number so that we don't change
925                // Shadow's behaviour, but also prevents an application from setting this to
926                // something unnecessarily large like INT_MAX.
927                let val = std::cmp::min(val, 268435456); // 2^28 = 256 MiB
928
929                self.send_buffer
930                    .set_soft_limit_bytes(val.try_into().unwrap());
931            }
932            (libc::SOL_SOCKET, libc::SO_RCVBUF) => {
933                type OptType = libc::c_int;
934
935                if usize::try_from(optlen).unwrap() < std::mem::size_of::<OptType>() {
936                    return Err(Errno::EINVAL.into());
937                }
938
939                let optval_ptr = optval_ptr.cast::<OptType>();
940                let val: u64 = mem.read(optval_ptr)?.try_into().or(Err(Errno::EINVAL))?;
941
942                // linux kernel doubles this value upon setting
943                let val = val * 2;
944
945                // Linux also has limits SOCK_MIN_RCVBUF (slightly greater than 2048) and the sysctl
946                // max limit. We choose a reasonable lower limit for Shadow. The minimum limit in
947                // man 7 socket is incorrect.
948                let val = std::cmp::max(val, 2048);
949
950                // This upper limit was added as an arbitrarily high number so that we don't change
951                // Shadow's behaviour, but also prevents an application from setting this to
952                // something unnecessarily large like INT_MAX.
953                let val = std::cmp::min(val, 268435456); // 2^28 = 256 MiB
954
955                self.recv_buffer
956                    .set_soft_limit_bytes(val.try_into().unwrap());
957            }
958            (libc::SOL_SOCKET, libc::SO_REUSEADDR) => {
959                // TODO: implement this
960                warn_once_then_debug!("setsockopt SO_REUSEADDR not yet implemented for udp");
961                return Err(Errno::ENOPROTOOPT.into());
962            }
963            (libc::SOL_SOCKET, libc::SO_REUSEPORT) => {
964                // TODO: implement this
965                warn_once_then_debug!("setsockopt SO_REUSEPORT not yet implemented for udp");
966                return Err(Errno::ENOPROTOOPT.into());
967            }
968            (libc::SOL_SOCKET, libc::SO_KEEPALIVE) => {
969                // TODO: implement this
970                warn_once_then_debug!("setsockopt SO_KEEPALIVE not yet implemented for udp");
971                return Err(Errno::ENOPROTOOPT.into());
972            }
973            (libc::SOL_SOCKET, libc::SO_BROADCAST) => {
974                type OptType = libc::c_int;
975
976                if usize::try_from(optlen).unwrap() < std::mem::size_of::<OptType>() {
977                    return Err(Errno::EINVAL.into());
978                }
979
980                let optval_ptr = optval_ptr.cast::<OptType>();
981                let val = mem.read(optval_ptr)?;
982
983                if val == 0 {
984                    // we don't support broadcast sockets, so an attempt to disable is okay
985                } else {
986                    // TODO: implement this, pkg.go.dev/net uses it
987                    warn_once_then_debug!(
988                        "setsockopt SO_BROADCAST not yet implemented for udp; ignoring and returning 0"
989                    );
990                }
991            }
992            _ => {
993                log_once_per_value_at_level!(
994                    (level, optname),
995                    (i32, i32),
996                    log::Level::Warn,
997                    log::Level::Debug,
998                    "setsockopt called with unsupported level {level} and opt {optname}"
999                );
1000                return Err(Errno::ENOPROTOOPT.into());
1001            }
1002        }
1003
1004        Ok(())
1005    }
1006
1007    pub fn add_listener(
1008        &mut self,
1009        monitoring_state: FileState,
1010        monitoring_signals: FileSignals,
1011        filter: StateListenerFilter,
1012        notify_fn: impl Fn(FileState, FileState, FileSignals, &mut CallbackQueue)
1013        + Send
1014        + Sync
1015        + 'static,
1016    ) -> StateListenHandle {
1017        self.event_source
1018            .add_listener(monitoring_state, monitoring_signals, filter, notify_fn)
1019    }
1020
1021    pub fn add_legacy_listener(&mut self, ptr: HostTreePointer<c::StatusListener>) {
1022        self.event_source.add_legacy_listener(ptr);
1023    }
1024
1025    pub fn remove_legacy_listener(&mut self, ptr: *mut c::StatusListener) {
1026        self.event_source.remove_legacy_listener(ptr);
1027    }
1028
    /// The socket's current file state, including the `READABLE`/`WRITABLE` bits maintained by
    /// `refresh_readable_writable`.
    pub fn state(&self) -> FileState {
        self.state
    }
1032
1033    fn refresh_readable_writable(&mut self, signals: FileSignals, cb_queue: &mut CallbackQueue) {
1034        let readable = !self.recv_buffer.is_empty();
1035        let writable = self.send_buffer.has_space();
1036
1037        let readable = readable.then_some(FileState::READABLE).unwrap_or_default();
1038        let writable = writable.then_some(FileState::WRITABLE).unwrap_or_default();
1039
1040        self.update_state(
1041            /* mask= */ FileState::READABLE | FileState::WRITABLE,
1042            readable | writable,
1043            signals,
1044            cb_queue,
1045        );
1046    }
1047
1048    fn update_state(
1049        &mut self,
1050        mask: FileState,
1051        state: FileState,
1052        signals: FileSignals,
1053        cb_queue: &mut CallbackQueue,
1054    ) {
1055        let old_state = self.state;
1056
1057        // remove the masked flags, then copy the masked flags
1058        self.state.remove(mask);
1059        self.state.insert(state & mask);
1060
1061        self.handle_state_change(old_state, signals, cb_queue);
1062    }
1063
1064    fn handle_state_change(
1065        &mut self,
1066        old_state: FileState,
1067        signals: FileSignals,
1068        cb_queue: &mut CallbackQueue,
1069    ) {
1070        let states_changed = self.state ^ old_state;
1071
1072        // if nothing changed
1073        if states_changed.is_empty() && signals.is_empty() {
1074            return;
1075        }
1076
1077        self.event_source
1078            .notify_listeners(self.state, states_changed, signals, cb_queue);
1079    }
1080}
1081
/// Non-payload data for a message in the send buffer.
///
/// Each queued payload in the socket's send `MessageBuffer` is stored paired with one of these.
#[derive(Debug)]
struct MessageSendHeader {
    /// The source address (typically the bind address). The application can theoretically use
    /// `IP_PKTINFO` to set a per-message source address.
    src: SocketAddrV4,
    /// The destination address (for example the peer).
    dst: SocketAddrV4,
    /// The priority for the packet that we'll create in the future, given to us by the host.
    packet_priority: FifoPacketPriority,
}
1093
/// Non-payload data for a message in the receive buffer.
///
/// Each queued payload in the socket's receive `MessageBuffer` is stored paired with one of
/// these.
#[derive(Debug)]
struct MessageRecvHeader {
    /// The source address (for example the peer).
    src: SocketAddrV4,
    /// The destination address (typically the bind address). The application can theoretically use
    /// `IP_PKTINFO` to get the packet destination address.
    // This field is not currently read, hence `allow(dead_code)`. NOTE(review): presumably kept
    // so that `IP_PKTINFO` support can be added later; confirm.
    #[allow(dead_code)]
    dst: SocketAddrV4,
    /// The time when the network interface received the message.
    recv_time: EmulatedTime,
}
1106
/// A buffer of UDP messages and message headers.
///
/// `Hdr` is the per-message metadata type (the socket uses `MessageSendHeader` for its send
/// buffer and `MessageRecvHeader` for its receive buffer).
#[derive(Debug)]
struct MessageBuffer<Hdr> {
    /// The message payloads and headers, oldest first.
    // use a `LinkedList` so that socket buffers can shrink when they're empty (as opposed to
    // `VecDeque`)
    buffer: LinkedList<(Bytes, Hdr)>,
    /// The number of payload bytes in this buffer (headers are not counted).
    len_bytes: usize,
    /// A soft limit for the maximum number of payload bytes this buffer can hold. It is "soft"
    /// because a push is only refused once `len_bytes` has already reached it, so the message
    /// that crosses the limit is still accepted.
    soft_limit_bytes: usize,
}
1119
1120impl<Hdr> MessageBuffer<Hdr> {
1121    pub fn new(soft_limit_bytes: usize) -> Self {
1122        Self {
1123            buffer: std::collections::LinkedList::new(),
1124            len_bytes: 0,
1125            soft_limit_bytes,
1126        }
1127    }
1128
1129    /// Push a message to the buffer. Returns the message and header as an `Err` if there wasn't
1130    /// enough space.
1131    pub fn push_message(&mut self, message: Bytes, header: Hdr) -> Result<(), (Bytes, Hdr)> {
1132        // TODO: i think udp allows at most one packet to exceed the buffer capacity; should confirm
1133        // this
1134        if !self.has_space() {
1135            return Err((message, header));
1136        }
1137
1138        // TODO: on linux the socket buffer length also takes into account any header and struct
1139        // overhead, otherwise the buffer would take an infinite amount of 0-len packets
1140        self.len_bytes += message.len();
1141        self.buffer.push_back((message, header));
1142
1143        Ok(())
1144    }
1145
1146    /// Pop the next message from the buffer. Returns a tuple of the message bytes and message
1147    /// header.
1148    pub fn pop_message(&mut self) -> Option<(Bytes, Hdr)> {
1149        let (message, header) = self.buffer.pop_front()?;
1150        self.len_bytes -= message.len();
1151
1152        Some((message, header))
1153    }
1154
1155    /// Peek the next message in the buffer.
1156    pub fn peek_message(&self) -> Option<&(Bytes, Hdr)> {
1157        self.buffer.front()
1158    }
1159
1160    /// The number of payload bytes contained in the buffer. A length of 0 does not mean that the
1161    /// buffer is empty.
1162    pub fn len_bytes(&self) -> usize {
1163        self.len_bytes
1164    }
1165
1166    /// Is there space for at least one more packet?
1167    pub fn has_space(&self) -> bool {
1168        self.len_bytes < self.soft_limit_bytes
1169    }
1170
1171    /// Is the buffer empty (does it have 0 packets)?
1172    pub fn is_empty(&self) -> bool {
1173        self.buffer.is_empty()
1174    }
1175
1176    /// The soft limit for the size of the buffer.
1177    pub fn soft_limit_bytes(&self) -> usize {
1178        self.soft_limit_bytes
1179    }
1180
1181    /// Set the soft limit for the size of the buffer.
1182    pub fn set_soft_limit_bytes(&mut self, soft_limit_bytes: usize) {
1183        self.soft_limit_bytes = soft_limit_bytes;
1184    }
1185}