shadow_rs/host/descriptor/socket/inet/tcp.rs

use std::net::{Ipv4Addr, SocketAddrV4};
use std::sync::{Arc, Weak};

use atomic_refcell::AtomicRefCell;
use linux_api::errno::Errno;
use linux_api::ioctls::IoctlRequest;
use linux_api::socket::Shutdown;
use nix::sys::socket::{MsgFlags, SockaddrIn};
use shadow_shim_helper_rs::emulated_time::EmulatedTime;
use shadow_shim_helper_rs::simulation_time::SimulationTime;
use shadow_shim_helper_rs::syscall_types::ForeignPtr;

use crate::core::work::task::TaskRef;
use crate::core::worker::Worker;
use crate::cshadow as c;
use crate::host::descriptor::listener::{StateEventSource, StateListenHandle, StateListenerFilter};
use crate::host::descriptor::socket::inet;
use crate::host::descriptor::socket::{InetSocket, RecvmsgArgs, RecvmsgReturn, SendmsgArgs};
use crate::host::descriptor::{File, Socket};
use crate::host::descriptor::{
    FileMode, FileSignals, FileState, FileStatus, OpenFile, SyscallResult,
};
use crate::host::memory_manager::MemoryManager;
use crate::host::network::interface::FifoPacketPriority;
use crate::host::network::namespace::{AssociationHandle, NetworkNamespace};
use crate::host::syscall::io::{IoVec, IoVecReader, IoVecWriter, write_partial};
use crate::host::syscall::types::SyscallError;
use crate::network::packet::{PacketRc, PacketStatus};
use crate::utility::callback_queue::CallbackQueue;
use crate::utility::sockaddr::SockaddrStorage;
use crate::utility::{HostTreePointer, ObjectCounter};

pub struct TcpSocket {
    tcp_state: tcp::TcpState<TcpDeps>,
    socket_weak: Weak<AtomicRefCell<Self>>,
    event_source: StateEventSource,
    status: FileStatus,
    file_state: FileState,
    association: Option<AssociationHandle>,
    connect_result_is_pending: bool,
    shutdown_status: Option<Shutdown>,
    // should only be used by `OpenFile` to make sure there is only ever one `OpenFile` instance for
    // this file
    has_open_file: bool,
    _counter: ObjectCounter,
}

impl TcpSocket {
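    /// Create a new [`TcpSocket`] with the given file status flags.
    ///
    /// A rough usage sketch (not compiled as a doc-test; assumes the caller wants a non-blocking
    /// socket):
    ///
    /// ```ignore
    /// let socket = TcpSocket::new(FileStatus::NONBLOCK);
    /// // the socket starts with ACTIVE set so that epoll works
    /// assert!(socket.borrow().state().contains(FileState::ACTIVE));
    /// ```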
    pub fn new(status: FileStatus) -> Arc<AtomicRefCell<Self>> {
        let rv = Arc::new_cyclic(|weak: &Weak<AtomicRefCell<Self>>| {
            let tcp_dependencies = TcpDeps {
                timer_state: Arc::new(AtomicRefCell::new(TcpDepsTimerState {
                    socket: weak.clone(),
                    registered_by: tcp::TimerRegisteredBy::Parent,
                })),
            };

            AtomicRefCell::new(Self {
                tcp_state: tcp::TcpState::new(tcp_dependencies, tcp::TcpConfig::default()),
                socket_weak: weak.clone(),
                event_source: StateEventSource::new(),
                status,
                // the readable/writable file state shouldn't matter here since we run
                // `with_tcp_state` below to update it, but we need ACTIVE set so that epoll works
                file_state: FileState::ACTIVE,
                association: None,
                connect_result_is_pending: false,
                shutdown_status: None,
                has_open_file: false,
                _counter: ObjectCounter::new("TcpSocket"),
            })
        });

        // run a no-op function on the state, which will force the socket to update its file state
        // to match the tcp state
        CallbackQueue::queue_and_run_with_legacy(|cb_queue| {
            rv.borrow_mut().with_tcp_state(cb_queue, |_state| ())
        });

        rv
    }

    pub fn status(&self) -> FileStatus {
        self.status
    }

    pub fn set_status(&mut self, status: FileStatus) {
        self.status = status;
    }

    pub fn mode(&self) -> FileMode {
        FileMode::READ | FileMode::WRITE
    }

    pub fn has_open_file(&self) -> bool {
        self.has_open_file
    }

    pub fn supports_sa_restart(&self) -> bool {
        true
    }

    pub fn set_has_open_file(&mut self, val: bool) {
        self.has_open_file = val;
    }

    fn with_tcp_state<T>(
        &mut self,
        cb_queue: &mut CallbackQueue,
        f: impl FnOnce(&mut tcp::TcpState<TcpDeps>) -> T,
    ) -> T {
        self.with_tcp_state_and_signal(cb_queue, |state| (f(state), FileSignals::empty()))
    }

    /// Update the current tcp state. The tcp state should only ever be updated through this method.
    fn with_tcp_state_and_signal<T>(
        &mut self,
        cb_queue: &mut CallbackQueue,
        f: impl FnOnce(&mut tcp::TcpState<TcpDeps>) -> (T, FileSignals),
    ) -> T {
        let rv = f(&mut self.tcp_state);

        // we may have mutated the tcp state, so update the socket's file state and notify listeners

        // if there are packets to send, notify the host
        if self.tcp_state.wants_to_send() {
            // The upgrade could fail if this was run during a drop, or if some outer code decided
            // to take the `TcpSocket` out of the `Arc` for some reason. Might as well panic since
            // it might indicate a bug somewhere else.
            let socket = self.socket_weak.upgrade().unwrap();

            // First try getting our IP address from the tcp state (if it's connected), then try
            // from the association handle (if it's not connected but is bound). Assume that our IP
            // address will match an interface's IP address.
            let interface_ip = *self
                .tcp_state
                .local_remote_addrs()
                .map(|x| x.0)
                .or(self.association.as_ref().map(|x| x.local_addr()))
                .unwrap()
                .ip();

            cb_queue.add(move |_cb_queue| {
                Worker::with_active_host(|host| {
                    let socket = InetSocket::Tcp(socket);
                    host.notify_socket_has_packets(interface_ip, &socket);
                })
                .unwrap();
            });
        }

        // the following mappings from `PollState` to `FileState` may be relied on by other parts of
        // the code, such as the `connect()` and `accept()` blocking behaviour, so be careful when
        // making changes

        let mut read_write_flags = FileState::empty();
        let poll_state = self.tcp_state.poll();

        if poll_state.intersects(tcp::PollState::READABLE | tcp::PollState::RECV_CLOSED) {
            read_write_flags.insert(FileState::READABLE);
        }
        if poll_state.intersects(tcp::PollState::WRITABLE) {
            read_write_flags.insert(FileState::WRITABLE);
        }
        if poll_state.intersects(tcp::PollState::READY_TO_ACCEPT) {
            read_write_flags.insert(FileState::READABLE);
        }
        if poll_state.intersects(tcp::PollState::ERROR) {
            read_write_flags.insert(FileState::READABLE | FileState::WRITABLE);
        }

        // if the socket/file is closed, undo all of the flags set above (closed sockets aren't
        // readable or writable)
        if self.file_state.contains(FileState::CLOSED) {
            read_write_flags = FileState::empty();
        }

        // overwrite readable/writable flags
        self.update_state(
            FileState::READABLE | FileState::WRITABLE,
            read_write_flags,
            rv.1,
            cb_queue,
        );

        // if the tcp state is in the closed state
        if poll_state.contains(tcp::PollState::CLOSED) {
            // drop the association handle so that we're removed from the network interface
            self.association = None;
            // we do not change to `FileState::CLOSED` here since that flag represents that the file
            // has closed (with `close()`), not that the tcp state has closed
        }

        rv.0
    }

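    /// Push an incoming packet from the network interface into the tcp state.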
    pub fn push_in_packet(
        &mut self,
        packet: PacketRc,
        cb_queue: &mut CallbackQueue,
        _recv_time: EmulatedTime,
    ) {
        packet.add_status(PacketStatus::RcvSocketProcessed);

        // TODO: don't bother copying the bytes if we know the push will fail

        // TODO: we have no way of adding `PacketStatus::RcvSocketDropped` if the tcp state drops
        // the packet

        let header = packet
            .ipv4_tcp_header()
            .expect("TCP socket received a non-tcp packet");

        // transfer the `Bytes` objects directly from the payload to the tcp state without copying
        // the bytes themselves

        let payload = tcp::Payload(packet.payload());
        assert_eq!(payload.len() as usize, packet.payload_len());

        self.with_tcp_state_and_signal(cb_queue, |s| {
            let pushed_len = s.push_packet(&header, payload).unwrap();
            let signals = if pushed_len > 0 {
                FileSignals::READ_BUFFER_GREW
            } else {
                FileSignals::empty()
            };
            ((), signals)
        });

        packet.add_status(PacketStatus::RcvSocketBuffered);
    }

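    /// Pop the next outgoing packet from the tcp state, if one is ready to be sent.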
    pub fn pull_out_packet(&mut self, cb_queue: &mut CallbackQueue) -> Option<PacketRc> {
        #[cfg(debug_assertions)]
        let wants_to_send = self.tcp_state.wants_to_send();

        // make sure that `self.has_data_to_send()` agrees with `tcp_state.wants_to_send()`
        #[cfg(debug_assertions)]
        debug_assert_eq!(self.has_data_to_send(), wants_to_send);

        // pop a packet from the socket
        let rv = self.with_tcp_state(cb_queue, |s| s.pop_packet());

        let (header, payload) = match rv {
            Ok(x) => x,
            Err(tcp::PopPacketError::NoPacket) => {
                #[cfg(debug_assertions)]
                debug_assert!(!wants_to_send);
                return None;
            }
            Err(tcp::PopPacketError::InvalidState) => {
                #[cfg(debug_assertions)]
                debug_assert!(!wants_to_send);
                return None;
            }
        };

        #[cfg(debug_assertions)]
        debug_assert!(wants_to_send);

        // We transfer the `Bytes` objects directly from the tcp state's `Payload` object to the
        // packet without copying the bytes themselves.
        // TODO: set packet priority?
        let packet = PacketRc::new_ipv4_tcp(header, payload, 0);
        packet.add_status(PacketStatus::SndCreated);

        Some(packet)
    }

    pub fn peek_next_packet_priority(&self) -> Option<FifoPacketPriority> {
        // TODO: support packet priorities?
        self.has_data_to_send().then_some(0)
    }

    pub fn has_data_to_send(&self) -> bool {
        self.tcp_state.wants_to_send()
    }

    pub fn getsockname(&self) -> Result<Option<SockaddrIn>, Errno> {
        // The socket state won't always have the local address. For example if the socket was bound
        // but connect() hasn't yet been called, the socket state will not have a local or remote
        // address. Instead we should get the local address from the association.
        Ok(Some(
            self.association
                .as_ref()
                .map(|x| x.local_addr().into())
                .unwrap_or(SockaddrIn::new(0, 0, 0, 0, 0)),
        ))
    }

    pub fn getpeername(&self) -> Result<Option<SockaddrIn>, Errno> {
        // The association won't always have the peer address. For example if the socket was bound
        // before connect() was called, the association will have a peer of 0.0.0.0. Instead we
        // should get the peer address from the socket state.
        Ok(Some(
            self.tcp_state
                .local_remote_addrs()
                .map(|x| x.1.into())
                .ok_or(Errno::ENOTCONN)?,
        ))

        // TODO: This will not have the remote address once the tcp state has closed (for example by
        // `shutdown(RDWR)`), in which case `local_remote_addrs()` will return `None` so this will
        // incorrectly return ENOTCONN. Should fix this somehow and add a test.

        // TODO: I don't think `getpeername()` should return a valid peer name before the connection
        // is successfully established.
    }

    pub fn address_family(&self) -> linux_api::socket::AddressFamily {
        linux_api::socket::AddressFamily::AF_INET
    }

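    /// Close the socket's tcp state and mark the file as [`FileState::CLOSED`].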
    pub fn close(&mut self, cb_queue: &mut CallbackQueue) -> Result<(), SyscallError> {
        // we don't expect close() to ever have an error
        self.with_tcp_state(cb_queue, |state| state.close())
            .unwrap();

        // add the closed flag and remove all other flags
        self.update_state(
            FileState::all(),
            FileState::CLOSED,
            FileSignals::empty(),
            cb_queue,
        );

        Ok(())
    }

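    /// Bind the socket to the given local address and associate it with the network namespace.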
    pub fn bind(
        socket: &Arc<AtomicRefCell<Self>>,
        addr: Option<&SockaddrStorage>,
        net_ns: &NetworkNamespace,
        rng: impl rand::Rng,
    ) -> Result<(), SyscallError> {
        // if the address pointer was NULL
        let Some(addr) = addr else {
            return Err(Errno::EFAULT.into());
        };

        // if not an inet socket address
        let Some(addr) = addr.as_inet() else {
            return Err(Errno::EINVAL.into());
        };

        let addr: SocketAddrV4 = (*addr).into();

        let mut socket_ref = socket.borrow_mut();

        // if the socket is already associated
        if socket_ref.association.is_some() {
            return Err(Errno::EINVAL.into());
        }

        // this will allow us to receive packets from any peer
        let peer_addr = SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0);

        // associate the socket
        let (_addr, handle) = inet::associate_socket(
            InetSocket::Tcp(Arc::clone(socket)),
            addr,
            peer_addr,
            /* check_generic_peer= */ true,
            net_ns,
            rng,
        )?;

        socket_ref.association = Some(handle);

        Ok(())
    }

    pub fn readv(
        &mut self,
        _iovs: &[IoVec],
        _offset: Option<libc::off_t>,
        _flags: libc::c_int,
        _mem: &mut MemoryManager,
        _cb_queue: &mut CallbackQueue,
    ) -> Result<libc::ssize_t, SyscallError> {
        // we could call TcpSocket::recvmsg() here, but for now we expect that there are no code
        // paths that would call TcpSocket::readv() since the readv() syscall handler should have
        // called TcpSocket::recvmsg() instead
        panic!("Called TcpSocket::readv() on a TCP socket");
    }

    pub fn writev(
        &mut self,
        _iovs: &[IoVec],
        _offset: Option<libc::off_t>,
        _flags: libc::c_int,
        _mem: &mut MemoryManager,
        _cb_queue: &mut CallbackQueue,
    ) -> Result<libc::ssize_t, SyscallError> {
        // we could call TcpSocket::sendmsg() here, but for now we expect that there are no code
        // paths that would call TcpSocket::writev() since the writev() syscall handler should have
        // called TcpSocket::sendmsg() instead
        panic!("Called TcpSocket::writev() on a TCP socket");
    }

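    /// Send data from the given iovecs through the tcp state. Returns a "blocked on file" syscall
    /// error if the send buffer is full and the socket is blocking.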
    pub fn sendmsg(
        socket: &Arc<AtomicRefCell<Self>>,
        args: SendmsgArgs,
        mem: &mut MemoryManager,
        _net_ns: &NetworkNamespace,
        _rng: impl rand::Rng,
        cb_queue: &mut CallbackQueue,
    ) -> Result<libc::ssize_t, SyscallError> {
        let mut socket_ref = socket.borrow_mut();

        let Some(mut flags) = MsgFlags::from_bits(args.flags) else {
            log::debug!("Unrecognized send flags: {:#b}", args.flags);
            return Err(Errno::EINVAL.into());
        };

        if socket_ref.status().contains(FileStatus::NONBLOCK) {
            flags.insert(MsgFlags::MSG_DONTWAIT);
        }

        let len: libc::size_t = args.iovs.iter().map(|x| x.len).sum();

        // run in a closure so that an early return doesn't skip checking if we should block
        let result = (|| {
            let reader = IoVecReader::new(args.iovs, mem);

            let rv = socket_ref.with_tcp_state(cb_queue, |state| state.send(reader, len));

            let num_sent = match rv {
                Ok(x) => x,
                Err(tcp::SendError::Full) => return Err(Errno::EWOULDBLOCK),
                Err(tcp::SendError::NotConnected) => return Err(Errno::EPIPE),
                Err(tcp::SendError::StreamClosed) => return Err(Errno::EPIPE),
                Err(tcp::SendError::Io(e)) => return Err(Errno::try_from(e).unwrap()),
                Err(tcp::SendError::InvalidState) => return Err(Errno::EINVAL),
            };

            Ok(num_sent)
        })();

        // if the syscall would block and we don't have the MSG_DONTWAIT flag
        if result == Err(Errno::EWOULDBLOCK) && !flags.contains(MsgFlags::MSG_DONTWAIT) {
            return Err(SyscallError::new_blocked_on_file(
                File::Socket(Socket::Inet(InetSocket::Tcp(socket.clone()))),
                FileState::WRITABLE | FileState::CLOSED,
                socket_ref.supports_sa_restart(),
            ));
        }

        Ok(result?.try_into().unwrap())
    }

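    /// Receive data from the tcp state into the given iovecs. Returns a "blocked on file" syscall
    /// error if no data is available and the socket is blocking.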
    pub fn recvmsg(
        socket: &Arc<AtomicRefCell<Self>>,
        args: RecvmsgArgs,
        mem: &mut MemoryManager,
        cb_queue: &mut CallbackQueue,
    ) -> Result<RecvmsgReturn, SyscallError> {
        let socket_ref = &mut *socket.borrow_mut();

        // if there was an asynchronous error, return it
        if let Some(error) = socket_ref.with_tcp_state(cb_queue, |state| state.clear_error()) {
            // by returning this error, we're probably (but not necessarily) returning a previous
            // connect() result
            socket_ref.connect_result_is_pending = false;

            return Err(tcp_error_to_errno(error).into());
        }

        let Some(mut flags) = MsgFlags::from_bits(args.flags) else {
            log::debug!("Unrecognized recv flags: {:#b}", args.flags);
            return Err(Errno::EINVAL.into());
        };

        if socket_ref.status().contains(FileStatus::NONBLOCK) {
            flags.insert(MsgFlags::MSG_DONTWAIT);
        }

        let len: libc::size_t = args.iovs.iter().map(|x| x.len).sum();

        // run in a closure so that an early return doesn't skip checking if we should block
        let result = (|| {
            let writer = IoVecWriter::new(args.iovs, mem);

            let rv = socket_ref.with_tcp_state(cb_queue, |state| state.recv(writer, len));

            let num_recv = match rv {
                Ok(x) => x,
                Err(tcp::RecvError::Empty) => {
                    if [Shutdown::SHUT_RD, Shutdown::SHUT_RDWR]
                        .map(Some)
                        .contains(&socket_ref.shutdown_status)
                    {
                        0
                    } else {
                        return Err(Errno::EWOULDBLOCK);
                    }
                }
                Err(tcp::RecvError::NotConnected) => return Err(Errno::ENOTCONN),
                Err(tcp::RecvError::StreamClosed) => 0,
                Err(tcp::RecvError::Io(e)) => return Err(Errno::try_from(e).unwrap()),
                Err(tcp::RecvError::InvalidState) => return Err(Errno::EINVAL),
            };

            Ok(RecvmsgReturn {
                return_val: num_recv.try_into().unwrap(),
                addr: None,
                msg_flags: MsgFlags::empty().bits(),
                control_len: 0,
            })
        })();

        // if the syscall would block and we don't have the MSG_DONTWAIT flag
        if result.as_ref().err() == Some(&Errno::EWOULDBLOCK)
            && !flags.contains(MsgFlags::MSG_DONTWAIT)
        {
            return Err(SyscallError::new_blocked_on_file(
                File::Socket(Socket::Inet(InetSocket::Tcp(socket.clone()))),
                FileState::READABLE | FileState::CLOSED,
                socket_ref.supports_sa_restart(),
            ));
        }

        Ok(result?)
    }

    pub fn ioctl(
        &mut self,
        _request: IoctlRequest,
        _arg_ptr: ForeignPtr<()>,
        _mem: &mut MemoryManager,
    ) -> SyscallResult {
        todo!();
    }

    pub fn stat(&self) -> Result<linux_api::stat::stat, SyscallError> {
        warn_once_then_debug!("We do not yet handle stat calls on tcp sockets");
        Err(Errno::EINVAL.into())
    }

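    /// Move the tcp state into the listening state, implicitly binding the socket to all
    /// interfaces if it isn't already associated.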
    pub fn listen(
        socket: &Arc<AtomicRefCell<Self>>,
        backlog: i32,
        net_ns: &NetworkNamespace,
        rng: impl rand::Rng,
        cb_queue: &mut CallbackQueue,
    ) -> Result<(), Errno> {
        let socket_ref = &mut *socket.borrow_mut();

        // linux also makes this cast, so negative backlogs wrap around to large positive backlogs
        // https://elixir.free-electrons.com/linux/v5.11.22/source/net/ipv4/af_inet.c#L212
        let backlog = backlog as u32;

        let is_associated = socket_ref.association.is_some();

        let rv = if is_associated {
            // if already associated, do nothing
            let associate_fn = || Ok(None);
            socket_ref.with_tcp_state(cb_queue, |state| state.listen(backlog, associate_fn))
        } else {
            // if not associated, associate and return the handle
            let associate_fn = || {
                // implicitly bind to all interfaces
                let local_addr = SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0);

                // want to receive packets from any address
                let peer_addr = SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0);
                let socket = Arc::clone(socket);

                // associate the socket
                let (_addr, handle) = inet::associate_socket(
                    InetSocket::Tcp(Arc::clone(&socket)),
                    local_addr,
                    peer_addr,
                    /* check_generic_peer= */ true,
                    net_ns,
                    rng,
                )?;

                Ok::<_, Errno>(Some(handle))
            };
            socket_ref.with_tcp_state(cb_queue, |state| state.listen(backlog, associate_fn))
        };

        let handle = match rv {
            Ok(x) => x,
            Err(tcp::ListenError::InvalidState) => return Err(Errno::EINVAL),
            Err(tcp::ListenError::FailedAssociation(e)) => return Err(e),
        };

        // the `associate_fn` may or may not have run, so `handle` may or may not be set
        if let Some(handle) = handle {
            assert!(socket_ref.association.is_none());
            socket_ref.association = Some(handle);
        }

        Ok(())
    }

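    /// Begin connecting the socket to a peer address, associating the socket if needed. A blocking
    /// socket returns a "blocked on file" syscall error until the connection attempt resolves; a
    /// non-blocking socket returns EINPROGRESS, and the result is reported by a later connect()
    /// call or by reading SO_ERROR.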
    pub fn connect(
        socket: &Arc<AtomicRefCell<Self>>,
        peer_addr: &SockaddrStorage,
        net_ns: &NetworkNamespace,
        rng: impl rand::Rng,
        cb_queue: &mut CallbackQueue,
    ) -> Result<(), SyscallError> {
        let socket_ref = &mut *socket.borrow_mut();

        // if there was an asynchronous error, return it
        if let Some(error) = socket_ref.with_tcp_state(cb_queue, |state| state.clear_error()) {
            // by returning this error, we're probably (but not necessarily) returning a previous
            // connect() result
            socket_ref.connect_result_is_pending = false;

            return Err(tcp_error_to_errno(error).into());
        }

        // if connect() had previously been called (either blocking or non-blocking), we need to
        // return the result
        if socket_ref.connect_result_is_pending {
            // ignore all connect arguments and just check if we've connected

            // check if it's still connecting (in the "syn-sent" or "syn-received" state)
            if socket_ref
                .tcp_state
                .poll()
                .contains(tcp::PollState::CONNECTING)
            {
                return Err(Errno::EALREADY.into());
            }

            // if not connecting and there were no socket errors (checked above)
            socket_ref.connect_result_is_pending = false;
            return Ok(());
        }

        // if not an inet socket address
        let Some(peer_addr) = peer_addr.as_inet() else {
            return Err(Errno::EINVAL.into());
        };

        let mut peer_addr: std::net::SocketAddrV4 = (*peer_addr).into();

        // On Linux a connection to 0.0.0.0 means a connection to localhost:
        // https://stackoverflow.com/a/22425796
        if peer_addr.ip().is_unspecified() {
            peer_addr.set_ip(std::net::Ipv4Addr::LOCALHOST);
        }

        let local_addr = socket_ref.association.as_ref().map(|x| x.local_addr());

        let rv = if let Some(mut local_addr) = local_addr {
            // the local address needs to be a specific address (this is normally what a routing
            // table would figure out for us)
            if local_addr.ip().is_unspecified() {
                if peer_addr.ip() == &std::net::Ipv4Addr::LOCALHOST {
                    local_addr.set_ip(Ipv4Addr::LOCALHOST)
                } else {
                    local_addr.set_ip(net_ns.default_ip)
                };
            }

            // it's already associated so use the existing address
            let associate_fn = || Ok((local_addr, None));
            socket_ref.with_tcp_state(cb_queue, |state| state.connect(peer_addr, associate_fn))
        } else {
            // if not associated, associate and return the handle
            let associate_fn = || {
                // the local address needs to be a specific address (this is normally what a routing
                // table would figure out for us)
                let local_addr = if peer_addr.ip() == &std::net::Ipv4Addr::LOCALHOST {
                    Ipv4Addr::LOCALHOST
                } else {
                    net_ns.default_ip
                };

                // add a wildcard port number
                let local_addr = SocketAddrV4::new(local_addr, 0);

                let (local_addr, handle) = inet::associate_socket(
                    InetSocket::Tcp(Arc::clone(socket)),
                    local_addr,
                    peer_addr,
                    /* check_generic_peer= */ true,
                    net_ns,
                    rng,
                )?;

                // use the actual local address that was assigned (will have port != 0)
                Ok((local_addr, Some(handle)))
            };
            socket_ref.with_tcp_state(cb_queue, |state| state.connect(peer_addr, associate_fn))
        };

        let handle = match rv {
            Ok(x) => x,
            Err(tcp::ConnectError::InProgress) => return Err(Errno::EALREADY.into()),
            Err(tcp::ConnectError::AlreadyConnected) => return Err(Errno::EISCONN.into()),
            Err(tcp::ConnectError::IsListening) => return Err(Errno::EISCONN.into()),
            Err(tcp::ConnectError::InvalidState) => return Err(Errno::EINVAL.into()),
            Err(tcp::ConnectError::FailedAssociation(e)) => return Err(e),
        };

        // the `associate_fn` may not have associated the socket, so `handle` may or may not be set
        if let Some(handle) = handle {
            assert!(socket_ref.association.is_none());
            socket_ref.association = Some(handle);
        }

        // we're attempting to connect, so set a flag so that we know a future connect() call should
        // return the result
        socket_ref.connect_result_is_pending = true;

        if socket_ref.status.contains(FileStatus::NONBLOCK) {
            Err(Errno::EINPROGRESS.into())
        } else {
            let err = SyscallError::new_blocked_on_file(
                File::Socket(Socket::Inet(InetSocket::Tcp(Arc::clone(socket)))),
                // I think we want this to resume when it leaves the "syn-sent" and "syn-received"
                // states (for example moves to the "rst", "closed", "fin-wait-1", etc states).
                //
                // - READABLE: the state may time out in the "syn-received" state and move to the
                //   "closed" state, which is `tcp::PollState::RECV_CLOSED` and maps to
                //   `FileState::READABLE`
                // - WRITABLE: the state may reach the "established" state which is
                //   `tcp::PollState::WRITABLE` which maps to `FileState::WRITABLE`
                // - CLOSED: we use this just to be safe; typically the `connect()` syscall handler
                //   would hold an `OpenFile` for this socket while the syscall is blocked which
                //   would prevent the socket from being closed until the syscall completed
                //
                // We assume here that the "syn-sent" and "syn-received" states never have the
                // `RECV_CLOSED`, `READABLE`, or `WRITABLE` `PollState` states, otherwise this
                // syscall condition would trigger while the socket was still connecting. This all
                // relies on the `PollState` to `FileState` mappings in `with_tcp_state()` above.
                FileState::READABLE | FileState::WRITABLE | FileState::CLOSED,
                socket_ref.supports_sa_restart(),
            );

            // block the current thread
            Err(err)
        }
    }

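    /// Accept a pending connection from the listening tcp state, returning a new [`OpenFile`]
    /// wrapping the accepted socket.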
    pub fn accept(
        &mut self,
        net_ns: &NetworkNamespace,
        rng: impl rand::Rng,
        cb_queue: &mut CallbackQueue,
    ) -> Result<OpenFile, SyscallError> {
        let rv = self.with_tcp_state(cb_queue, |state| state.accept());

        let accepted_state = match rv {
            Ok(x) => x,
            Err(tcp::AcceptError::InvalidState) => return Err(Errno::EINVAL.into()),
            Err(tcp::AcceptError::NothingToAccept) => return Err(Errno::EAGAIN.into()),
        };

        let local_addr = accepted_state.local_addr();
        let remote_addr = accepted_state.remote_addr();

        // convert the accepted tcp state to a full tcp socket
        let new_socket = Arc::new_cyclic(|weak: &Weak<AtomicRefCell<Self>>| {
            let accepted_state = accepted_state.finalize(|deps| {
                // update the timer state for new and existing pending timers to use the new
                // accepted socket rather than the parent listening socket
                let timer_state = &mut *deps.timer_state.borrow_mut();
                timer_state.socket = weak.clone();
                timer_state.registered_by = tcp::TimerRegisteredBy::Parent;
            });

            AtomicRefCell::new(Self {
                tcp_state: accepted_state,
                socket_weak: weak.clone(),
                event_source: StateEventSource::new(),
                status: FileStatus::empty(),
                // the readable/writable file state shouldn't matter here since we run
                // `with_tcp_state` below to update it, but we need ACTIVE set so that epoll works
                file_state: FileState::ACTIVE,
                association: None,
                connect_result_is_pending: false,
                shutdown_status: None,
                has_open_file: false,
                _counter: ObjectCounter::new("TcpSocket"),
            })
        });

        // run a no-op function on the state, which will force the socket to update its file state
        // to match the tcp state
        new_socket
            .borrow_mut()
            .with_tcp_state(cb_queue, |_state| ());

        // TODO: if the association fails, we lose the child socket

        // associate the socket
        let (_addr, handle) = inet::associate_socket(
            InetSocket::Tcp(Arc::clone(&new_socket)),
            local_addr,
            remote_addr,
            /* check_generic_peer= */ false,
            net_ns,
            rng,
        )?;

        new_socket.borrow_mut().association = Some(handle);

        Ok(OpenFile::new(File::Socket(Socket::Inet(InetSocket::Tcp(
            new_socket,
        )))))
    }

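    /// Shut down the read and/or write halves of the connection, merging the request with any
    /// previous shutdown() calls.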
    pub fn shutdown(
        &mut self,
        how: Shutdown,
        cb_queue: &mut CallbackQueue,
    ) -> Result<(), SyscallError> {
        // Update `how` based on any previous shutdown() calls. For example if shutdown(RD) was
        // previously called and now shutdown(WR) has been called, we should call shutdown(RDWR) on
        // the tcp state.
        let how = match (how, self.shutdown_status) {
            // if it was previously `SHUT_RDWR`
            (_, Some(Shutdown::SHUT_RDWR)) => Shutdown::SHUT_RDWR,
            // if it's now `SHUT_RDWR`
            (Shutdown::SHUT_RDWR, _) => Shutdown::SHUT_RDWR,
            (Shutdown::SHUT_RD, None | Some(Shutdown::SHUT_RD)) => Shutdown::SHUT_RD,
            (Shutdown::SHUT_RD, Some(Shutdown::SHUT_WR)) => Shutdown::SHUT_RDWR,
            (Shutdown::SHUT_WR, None | Some(Shutdown::SHUT_WR)) => Shutdown::SHUT_WR,
            (Shutdown::SHUT_WR, Some(Shutdown::SHUT_RD)) => Shutdown::SHUT_RDWR,
        };

        // Linux and the tcp library interpret shutdown flags differently. In the tcp library,
        // `tcp::Shutdown` has a very specific meaning for `SHUT_RD` and `SHUT_WR`, whereas Linux's
        // behaviour is undocumented and not straightforward. Here we try to map from the Linux
        // behaviour to the tcp library behaviour.
        let tcp_how = match how {
            Shutdown::SHUT_RD => None,
            Shutdown::SHUT_WR => Some(tcp::Shutdown::Write),
            Shutdown::SHUT_RDWR => Some(tcp::Shutdown::Both),
        };

        if let Some(tcp_how) = tcp_how {
            if let Err(e) = self.with_tcp_state(cb_queue, |state| state.shutdown(tcp_how)) {
                match e {
                    tcp::ShutdownError::NotConnected => return Err(Errno::ENOTCONN.into()),
                    tcp::ShutdownError::InvalidState => return Err(Errno::EINVAL.into()),
                }
            }
        } else {
            // we don't need to call shutdown() on the tcp state since we don't actually want to do
            // anything, but we still need to return ENOTCONN sometimes

            let not_connected = !self
                .tcp_state
                .poll()
                .intersects(tcp::PollState::CONNECTING | tcp::PollState::CONNECTED);

            if not_connected {
                return Err(Errno::ENOTCONN.into());
            }
        }

        // the shutdown was successful, so update our shutdown status
        self.shutdown_status = Some(how);

        Ok(())
    }

    pub fn getsockopt(
        &mut self,
        level: libc::c_int,
        optname: libc::c_int,
        optval_ptr: ForeignPtr<()>,
        optlen: libc::socklen_t,
        mem: &mut MemoryManager,
        cb_queue: &mut CallbackQueue,
    ) -> Result<libc::socklen_t, SyscallError> {
        match (level, optname) {
            (libc::SOL_SOCKET, libc::SO_ERROR) => {
                // may update the socket's state (for example, reading `SO_ERROR` will make `poll()`
                // stop returning `POLLERR` for the socket)
                let error = self.with_tcp_state(cb_queue, |state| state.clear_error());
                let error = error.map(tcp_error_to_errno).map(Into::into).unwrap_or(0);

                let optval_ptr = optval_ptr.cast::<libc::c_int>();
                let bytes_written = write_partial(mem, &error, optval_ptr, optlen as usize)?;

                Ok(bytes_written as libc::socklen_t)
            }
            (libc::SOL_SOCKET, libc::SO_DOMAIN) => {
                let domain = libc::AF_INET;

                let optval_ptr = optval_ptr.cast::<libc::c_int>();
                let bytes_written = write_partial(mem, &domain, optval_ptr, optlen as usize)?;

                Ok(bytes_written as libc::socklen_t)
            }
            (libc::SOL_SOCKET, libc::SO_TYPE) => {
                let sock_type = libc::SOCK_STREAM;

                let optval_ptr = optval_ptr.cast::<libc::c_int>();
                let bytes_written = write_partial(mem, &sock_type, optval_ptr, optlen as usize)?;

                Ok(bytes_written as libc::socklen_t)
            }
            (libc::SOL_SOCKET, libc::SO_PROTOCOL) => {
                let protocol = libc::IPPROTO_TCP;

                let optval_ptr = optval_ptr.cast::<libc::c_int>();
                let bytes_written = write_partial(mem, &protocol, optval_ptr, optlen as usize)?;

                Ok(bytes_written as libc::socklen_t)
            }
            (libc::SOL_SOCKET, libc::SO_ACCEPTCONN) => {
                let is_listener = self.tcp_state.poll().contains(tcp::PollState::LISTENING);
                let is_listener = is_listener as libc::c_int;

                let optval_ptr = optval_ptr.cast::<libc::c_int>();
                let bytes_written = write_partial(mem, &is_listener, optval_ptr, optlen as usize)?;

                Ok(bytes_written as libc::socklen_t)
            }
            (libc::SOL_SOCKET, libc::SO_BROADCAST) => {
                let optval_ptr = optval_ptr.cast::<libc::c_int>();
                // we don't support broadcast sockets, so just return the default 0
                let bytes_written = write_partial(mem, &0, optval_ptr, optlen as usize)?;

                Ok(bytes_written as libc::socklen_t)
            }
            _ => {
                log_once_per_value_at_level!(
                    (level, optname),
                    (i32, i32),
                    log::Level::Warn,
                    log::Level::Debug,
                    "getsockopt called with unsupported level {level} and opt {optname}"
                );
                Err(Errno::ENOPROTOOPT.into())
            }
        }
    }

    pub fn setsockopt(
        &mut self,
        level: libc::c_int,
        optname: libc::c_int,
        optval_ptr: ForeignPtr<()>,
        optlen: libc::socklen_t,
        mem: &MemoryManager,
    ) -> Result<(), SyscallError> {
        match (level, optname) {
            (libc::SOL_SOCKET, libc::SO_REUSEADDR) => {
                // TODO: implement this, tor and tgen use it
                log::trace!("setsockopt SO_REUSEADDR not yet implemented");
            }
            (libc::SOL_SOCKET, libc::SO_REUSEPORT) => {
                // TODO: implement this, tgen uses it
                log::trace!("setsockopt SO_REUSEPORT not yet implemented");
            }
            (libc::SOL_SOCKET, libc::SO_KEEPALIVE) => {
                // TODO: implement this, libevent uses it in evconnlistener_new_bind()
                log::trace!("setsockopt SO_KEEPALIVE not yet implemented");
            }
            (libc::SOL_SOCKET, libc::SO_BROADCAST) => {
                type OptType = libc::c_int;

                if usize::try_from(optlen).unwrap() < std::mem::size_of::<OptType>() {
                    return Err(Errno::EINVAL.into());
                }

                let optval_ptr = optval_ptr.cast::<OptType>();
                let val = mem.read(optval_ptr)?;

                if val == 0 {
                    // we don't support broadcast sockets, so an attempt to disable is okay
                } else {
                    // TODO: implement this, pkg.go.dev/net uses it
                    warn_once_then_debug!(
                        "setsockopt SO_BROADCAST not yet implemented for tcp; ignoring and returning 0"
                    );
                }
            }
            _ => {
                log_once_per_value_at_level!(
                    (level, optname),
                    (i32, i32),
                    log::Level::Warn,
                    log::Level::Debug,
                    "setsockopt called with unsupported level {level} and opt {optname}"
                );
                return Err(Errno::ENOPROTOOPT.into());
            }
        }

        Ok(())
    }

    pub fn add_listener(
        &mut self,
        monitoring_state: FileState,
        monitoring_signals: FileSignals,
        filter: StateListenerFilter,
        notify_fn: impl Fn(FileState, FileState, FileSignals, &mut CallbackQueue)
        + Send
        + Sync
        + 'static,
    ) -> StateListenHandle {
        self.event_source
            .add_listener(monitoring_state, monitoring_signals, filter, notify_fn)
    }

    pub fn add_legacy_listener(&mut self, ptr: HostTreePointer<c::StatusListener>) {
        self.event_source.add_legacy_listener(ptr);
    }

    pub fn remove_legacy_listener(&mut self, ptr: *mut c::StatusListener) {
        self.event_source.remove_legacy_listener(ptr);
    }

    pub fn state(&self) -> FileState {
        self.file_state
    }

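    /// Overwrite the masked bits of the file state with `state` and notify listeners if anything
    /// changed.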
    fn update_state(
        &mut self,
        mask: FileState,
        state: FileState,
        signals: FileSignals,
        cb_queue: &mut CallbackQueue,
    ) {
        let old_state = self.file_state;

        // remove the masked flags, then copy the masked flags
        self.file_state.remove(mask);
        self.file_state.insert(state & mask);

        self.handle_state_change(old_state, signals, cb_queue);
    }

    fn handle_state_change(
        &mut self,
        old_state: FileState,
        signals: FileSignals,
        cb_queue: &mut CallbackQueue,
    ) {
        let states_changed = self.file_state ^ old_state;

        // if nothing changed
        if states_changed.is_empty() && signals.is_empty() {
            return;
        }

        self.event_source
            .notify_listeners(self.file_state, states_changed, signals, cb_queue);
    }
}

fn tcp_error_to_errno(error: tcp::TcpError) -> Errno {
    match error {
        tcp::TcpError::ResetSent => Errno::ECONNRESET,
        // TODO: when should this be ECONNREFUSED vs ECONNRESET? maybe we need more context?
        tcp::TcpError::ResetReceived => Errno::ECONNREFUSED,
        tcp::TcpError::ClosedWhileConnecting => Errno::ECONNRESET,
        tcp::TcpError::TimedOut => Errno::ETIMEDOUT,
    }
}

/// Shared state stored in timers. This allows us to update existing timers when a child `TcpState`
/// is accept()ed and becomes owned by a new `TcpSocket` object.
#[derive(Debug)]
struct TcpDepsTimerState {
    /// The socket that the timer callback will run on.
    socket: Weak<AtomicRefCell<TcpSocket>>,
    /// Whether the timer callback should modify the state of this socket
    /// ([`TimerRegisteredBy::Parent`]), or one of its child sockets ([`TimerRegisteredBy::Child`]).
    registered_by: tcp::TimerRegisteredBy,
}

/// The dependencies required by `TcpState::new()` so that the tcp code can interact with the
/// simulator.
#[derive(Debug)]
struct TcpDeps {
    /// State shared between all timers registered from this `TcpDeps`. This is needed since we may
    /// need to update existing pending timers when we accept() a `TcpState` from a listening
    /// state.
    timer_state: Arc<AtomicRefCell<TcpDepsTimerState>>,
}

impl tcp::Dependencies for TcpDeps {
    type Instant = EmulatedTime;
    type Duration = SimulationTime;

    fn register_timer(
        &self,
        time: Self::Instant,
        f: impl FnOnce(&mut tcp::TcpState<Self>, tcp::TimerRegisteredBy) + Send + Sync + 'static,
    ) {
        // make sure the socket is kept alive in the closure while the timer is waiting to be run
        // (don't store a weak reference), otherwise the socket may have already been dropped and
        // the timer won't run
        // TODO: is this the behaviour we want?
        let timer_state = self.timer_state.borrow();
        let socket = timer_state.socket.upgrade().unwrap();
        let registered_by = timer_state.registered_by;

        // This is needed because `TaskRef` takes a `Fn`, but we have a `FnOnce`. It would be nice
        // if we could schedule a task that is guaranteed to run only once so we could avoid this
        // extra allocation and atomic. Instead we'll panic if it does run more than once.
        let f = Arc::new(AtomicRefCell::new(Some(f)));

        // schedule a task with the host
        Worker::with_active_host(|host| {
            let task = TaskRef::new(move |_host| {
                // take ownership of the closure; will panic if the task is run more than once
                let f = f.borrow_mut().take().unwrap();

                // run the original closure on the tcp state
                CallbackQueue::queue_and_run_with_legacy(|cb_queue| {
                    socket.borrow_mut().with_tcp_state(cb_queue, |state| {
                        f(state, registered_by);
                    })
                });
            });

            host.schedule_task_at_emulated_time(task, time);
        })
        .unwrap();
    }

    fn current_time(&self) -> Self::Instant {
        Worker::current_time().unwrap()
    }

    fn fork(&self) -> Self {
        let timer_state = self.timer_state.borrow();

        // if a child is trying to fork(), something has gone wrong
        assert_eq!(timer_state.registered_by, tcp::TimerRegisteredBy::Parent);

        Self {
            timer_state: Arc::new(AtomicRefCell::new(TcpDepsTimerState {
                socket: timer_state.socket.clone(),
                registered_by: tcp::TimerRegisteredBy::Child,
            })),
        }
    }
}