tcp/
states.rs

1use std::collections::{HashMap, LinkedList};
2use std::io::{Read, Write};
3use std::net::SocketAddrV4;
4
5use crate::buffer::RecvQueue;
6use crate::connection::Connection;
7use crate::seq::Seq;
8use crate::util::remove_from_list;
9use crate::util::time::Duration;
10use crate::{
11    AcceptError, AcceptedTcpState, CloseError, ConnectError, Dependencies, ListenError, Payload,
12    PollState, PopPacketError, PushPacketError, RecvError, RstCloseError, SendError, Shutdown,
13    ShutdownError, TcpConfig, TcpError, TcpFlags, TcpHeader, TcpState, TcpStateEnum, TcpStateTrait,
14    TimerRegisteredBy,
15};
16
17// state structs
18
/// The initial state of the TCP socket. While it's not a part of the official TCP state diagram, we
/// don't want to overload the "closed" state to mean both a closed socket and a never used socket,
/// since we don't allow TCP socket re-use.
#[derive(Debug)]
pub struct InitState<X: Dependencies> {
    /// State data shared by all state structs (dependencies, child key, pending error).
    pub(crate) common: Common<X>,
    /// Configuration handed on to the listening state or to the new connection when the socket
    /// leaves this state.
    pub(crate) config: TcpConfig,
}
27
/// State of a socket that is listening for incoming connections.
///
/// Besides its own data, the listener owns one child TCP state per incoming connection, together
/// with the bookkeeping (`conn_map`, `accept_queue`, `to_send`) that must stay consistent with
/// those children.
#[derive(Debug)]
pub struct ListenState<X: Dependencies> {
    /// State data shared by all state structs (dependencies, child key, pending error).
    pub(crate) common: Common<X>,
    /// Configuration passed to the connection of each newly registered child.
    pub(crate) config: TcpConfig,
    /// Queue limit, stored as the `listen()` backlog plus one. It is compared against both the
    /// accept queue length and the number of children that are not in the accept queue.
    pub(crate) max_backlog: u32,
    /// Packet headers that the listener itself wants to send. These are popped before any packet
    /// belonging to a child.
    pub(crate) send_buffer: LinkedList<TcpHeader>,
    /// Child TCP states.
    ///
    /// Child states should only be mutated through the [`with_child`](Self::with_child) method to
    /// ensure that this parent stays in sync with the child.
    pub(crate) children: slotmap::DenseSlotMap<ChildTcpKey, ChildEntry<X>>,
    /// A map from a connection's address pair (the packet's source address as "remote", the
    /// packet's destination address as "local") to the child handling that connection. Packets
    /// received with a matching address pair will be forwarded to the child.
    pub(crate) conn_map: HashMap<RemoteLocalPair, ChildTcpKey>,
    /// A queue of child TCP states that are ready to be accept()ed (children in the "established"
    /// or "close-wait" state).
    pub(crate) accept_queue: LinkedList<ChildTcpKey>,
    /// A list of child TCP states that want to send a packet.
    pub(crate) to_send: LinkedList<ChildTcpKey>,
}
47
/// State struct for the "syn-sent" TCP state, entered when `connect()` is called on a socket in
/// the initial state.
#[derive(Debug)]
pub struct SynSentState<X: Dependencies> {
    /// State data shared by all state structs (dependencies, child key, pending error).
    pub(crate) common: Common<X>,
    /// The connection that this state drives.
    pub(crate) connection: Connection<X::Instant>,
}
53
/// State struct for the "syn-received" TCP state. A listening socket creates a child in this state
/// for each SYN packet that opens a new connection.
#[derive(Debug)]
pub struct SynReceivedState<X: Dependencies> {
    /// State data shared by all state structs (dependencies, child key, pending error).
    pub(crate) common: Common<X>,
    /// The connection that this state drives.
    pub(crate) connection: Connection<X::Instant>,
}
59
/// State struct for the "established" TCP state. Children of a listening socket that reach this
/// state are placed in the listener's accept queue.
#[derive(Debug)]
pub struct EstablishedState<X: Dependencies> {
    /// State data shared by all state structs (dependencies, child key, pending error).
    pub(crate) common: Common<X>,
    /// The connection that this state drives.
    pub(crate) connection: Connection<X::Instant>,
}
65
/// State struct for the "fin-wait-1" TCP state.
#[derive(Debug)]
pub struct FinWaitOneState<X: Dependencies> {
    /// State data shared by all state structs (dependencies, child key, pending error).
    pub(crate) common: Common<X>,
    /// The connection that this state drives.
    pub(crate) connection: Connection<X::Instant>,
}
71
/// State struct for the "fin-wait-2" TCP state.
#[derive(Debug)]
pub struct FinWaitTwoState<X: Dependencies> {
    /// State data shared by all state structs (dependencies, child key, pending error).
    pub(crate) common: Common<X>,
    /// The connection that this state drives.
    pub(crate) connection: Connection<X::Instant>,
}
77
/// State struct for the "closing" TCP state.
#[derive(Debug)]
pub struct ClosingState<X: Dependencies> {
    /// State data shared by all state structs (dependencies, child key, pending error).
    pub(crate) common: Common<X>,
    /// The connection that this state drives.
    pub(crate) connection: Connection<X::Instant>,
}
83
/// State struct for the "time-wait" TCP state.
#[derive(Debug)]
pub struct TimeWaitState<X: Dependencies> {
    /// State data shared by all state structs (dependencies, child key, pending error).
    pub(crate) common: Common<X>,
    /// The connection that this state drives.
    pub(crate) connection: Connection<X::Instant>,
}
89
/// State struct for the "close-wait" TCP state. Children of a listening socket in this state are
/// kept in the listener's accept queue, just like "established" children.
#[derive(Debug)]
pub struct CloseWaitState<X: Dependencies> {
    /// State data shared by all state structs (dependencies, child key, pending error).
    pub(crate) common: Common<X>,
    /// The connection that this state drives.
    pub(crate) connection: Connection<X::Instant>,
}
95
/// State struct for the "last-ack" TCP state.
#[derive(Debug)]
pub struct LastAckState<X: Dependencies> {
    /// State data shared by all state structs (dependencies, child key, pending error).
    pub(crate) common: Common<X>,
    /// The connection that this state drives.
    pub(crate) connection: Connection<X::Instant>,
}
101
/// A state for sockets that need to send RST packets before closing. While it's not a part of the
/// official TCP state diagram, we need to be able to buffer RST packets to send. We can't buffer
/// RST packets in the "closed" state since the "closed" state is not allowed to send packets, so we
/// use this as an intermediate state before we move to the "closed" state. We may need to buffer
/// several RST packets; for example states in the "listening" state might need to send an RST
/// packet for each child.
#[derive(Debug)]
pub struct RstState<X: Dependencies> {
    /// State data shared by all state structs (dependencies, child key, pending error).
    pub(crate) common: Common<X>,
    /// The RST packet headers that still need to be sent.
    pub(crate) send_buffer: LinkedList<TcpHeader>,
    /// Was the socket previously connected? Should be `true` for any states that have previously
    /// been in the "syn-sent" or "syn-received" states. The connection does not need to have been
    /// successful (for example it may have timed out in the "syn-sent" state or may have been
    /// reset).
    pub(crate) was_connected: bool,
}
118
/// State struct for the "closed" TCP state. Unlike [`InitState`], this represents a socket that
/// has been closed rather than one that has never been used.
#[derive(Debug)]
pub struct ClosedState<X: Dependencies> {
    /// State data shared by all state structs (dependencies, child key, pending error).
    pub(crate) common: Common<X>,
    /// Receive queue kept by the closed socket.
    // NOTE(review): presumably holds data received before the close that can still be read;
    // confirm against `ClosedState::new` and its `recv` implementation.
    pub(crate) recv_buffer: RecvQueue,
    /// Was the socket previously connected? Should be `true` for any states that have previously
    /// been in the "syn-sent" or "syn-received" states. The connection does not need to have been
    /// successful (for example it may have timed out in the "syn-sent" state or may have been
    /// reset).
    pub(crate) was_connected: bool,
}
129
130// other helper types
131
/// Indicates that no child exists for the given [key](ChildTcpKey).
///
/// Returned as the error type of the listening state's child helpers when the key does not refer
/// to a child that is currently registered.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
struct ChildNotFound;
135
/// Data that every state struct carries and that is handed from one state to the next on a state
/// transition.
#[derive(Debug)]
pub(crate) struct Common<X: Dependencies> {
    /// The external dependencies (used for the current time and for registering timers).
    pub(crate) deps: X,
    /// If the current state is a child of a parent state, this should be the key that the parent
    /// can use to look up this child state.
    pub(crate) child_key: Option<ChildTcpKey>,
    /// The pending error for this socket, if any. It is reported through `poll()` and taken by
    /// `clear_error()`.
    pub(crate) error: Option<TcpError>,
}
144
145impl<X: Dependencies> Common<X> {
146    /// Register a timer for this state.
147    ///
148    /// This method will make sure that the callback gets run on the correct state, even if called
149    /// by a child state.
150    pub fn register_timer(
151        &self,
152        time: X::Instant,
153        f: impl FnOnce(TcpStateEnum<X>) -> TcpStateEnum<X> + Send + Sync + 'static,
154    ) {
155        // the handle that identifies this state if the state is a child of some parent state
156        let child_key = self.child_key;
157
158        // takes an owned `TcpStateEnum` and returns a `TcpStateEnum`
159        let timer_cb_inner = move |mut parent_state, state_type| {
160            match state_type {
161                // we're the parent and the timer was registered by us
162                TimerRegisteredBy::Parent => f(parent_state),
163                // we're the parent and the timer was registered by a child
164                TimerRegisteredBy::Child => {
165                    // if not in the listening state anymore, then the child must not exist
166                    let TcpStateEnum::Listen(parent_listen_state) = &mut parent_state else {
167                        // do nothing
168                        return parent_state;
169                    };
170
171                    // we need to lookup the child in `state` and run f() on the child's state
172                    // instead
173
174                    let child_key = child_key.expect(
175                        "The timer was supposedly registered by a child state, but there was no \
176                        key to identify the child",
177                    );
178
179                    let rv = parent_listen_state.with_child(child_key, |state| (f(state), ()));
180
181                    // in practice we want to ignore the error, but by doing a match here we make
182                    // sure that if the return type of `with_child` changes in the future, this code
183                    // will break and we can update it
184                    #[allow(clippy::single_match)]
185                    match rv {
186                        Ok(()) => {}
187                        // we ignore this since the child may have been closed
188                        Err(ChildNotFound) => {}
189                    }
190
191                    parent_state
192                }
193            }
194        };
195
196        // mutates a reference to a `TcpState` (this is a separate closure since it saves us two
197        // levels of indentation in the inner closure above)
198        let timer_cb = move |parent_state: &mut TcpState<X>, state_type| {
199            parent_state.with_state(|state| (timer_cb_inner(state, state_type), ()))
200        };
201
202        self.deps.register_timer(time, timer_cb);
203    }
204
205    pub fn current_time(&self) -> X::Instant {
206        self.deps.current_time()
207    }
208
209    /// Returns true if the error was set, or false if the error was previously set and was not
210    /// modified.
211    pub fn set_error_if_unset(&mut self, new_error: TcpError) -> bool {
212        if self.error.is_none() {
213            self.error = Some(new_error);
214            return true;
215        }
216
217        false
218    }
219}
220
/// The remote and local socket addresses of a connection (together, the connection's 4-tuple).
///
/// Two pairs are equal (and hash the same) only if both addresses match, which is what allows
/// this type to be used as a map key for looking up connections.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub(crate) struct RemoteLocalPair {
    /// The peer's address: the source of a packet we received, or the destination of a packet
    /// we're sending.
    remote: SocketAddrV4,
    /// Our own address: the destination of a packet we received, or the source of a packet we're
    /// sending.
    local: SocketAddrV4,
}

impl RemoteLocalPair {
    /// Build a pair from the `remote` (peer) address and the `local` (our) address.
    pub fn new(remote: SocketAddrV4, local: SocketAddrV4) -> Self {
        RemoteLocalPair { remote, local }
    }
}
237
// Key type for the slot map in which a listening state stores its child states. A child also
// keeps its own key in `Common::child_key` so that its timers can find it again through the
// parent.
slotmap::new_key_type! { pub(crate) struct ChildTcpKey; }
239
/// A child TCP state owned by a listening state, along with the addresses of its connection.
#[derive(Debug)]
pub(crate) struct ChildEntry<X: Dependencies> {
    /// The `Option` is required so that we can run [`TcpState`] methods that require `self`, for
    /// example `child.push_packet()`. It is only `None` while such a method is running on the
    /// state that was temporarily taken out.
    state: Option<TcpStateEnum<X>>,
    /// The remote/local address pair of the child's connection; this is the child's key in the
    /// parent's `conn_map`.
    conn_addrs: RemoteLocalPair,
}
247
248// state implementations
249
250impl<X: Dependencies> InitState<X> {
251    pub fn new(deps: X, config: TcpConfig) -> Self {
252        let common = Common {
253            deps,
254            child_key: None,
255            error: None,
256        };
257
258        InitState { common, config }
259    }
260}
261
262impl<X: Dependencies> TcpStateTrait<X> for InitState<X> {
263    fn close(self) -> (TcpStateEnum<X>, Result<(), CloseError>) {
264        let new_state = ClosedState::new(self.common, None, /* was_connected= */ false);
265        (new_state.into(), Ok(()))
266    }
267
268    fn rst_close(self) -> (TcpStateEnum<X>, Result<(), RstCloseError>) {
269        // no need to send a RST; closing immediately
270        let new_state = ClosedState::new(self.common, None, /* was_connected= */ false);
271        (new_state.into(), Ok(()))
272    }
273
274    fn shutdown(self, _how: Shutdown) -> (TcpStateEnum<X>, Result<(), ShutdownError>) {
275        (self.into(), Err(ShutdownError::NotConnected))
276    }
277
278    fn listen<T, E>(
279        self,
280        backlog: u32,
281        associate_fn: impl FnOnce() -> Result<T, E>,
282    ) -> (TcpStateEnum<X>, Result<T, ListenError<E>>) {
283        let rv = match associate_fn() {
284            Ok(x) => x,
285            Err(e) => return (self.into(), Err(ListenError::FailedAssociation(e))),
286        };
287
288        // linux uses a queue limit of one greater than the provided backlog
289        let max_backlog = backlog.saturating_add(1);
290
291        let new_state = ListenState::new(self.common, self.config, max_backlog);
292        (new_state.into(), Ok(rv))
293    }
294
295    fn connect<T, E>(
296        self,
297        remote_addr: SocketAddrV4,
298        associate_fn: impl FnOnce() -> Result<(SocketAddrV4, T), E>,
299    ) -> (TcpStateEnum<X>, Result<T, ConnectError<E>>) {
300        let assoc_result = associate_fn();
301
302        let (local_addr, assoc_result) = match assoc_result {
303            Ok((local_addr, assoc_result)) => (local_addr, assoc_result),
304            Err(e) => return (self.into(), Err(ConnectError::FailedAssociation(e))),
305        };
306
307        assert!(!local_addr.ip().is_unspecified());
308
309        let connection = Connection::new(local_addr, remote_addr, Seq::new(0), self.config);
310
311        let new_state = SynSentState::new(self.common, connection);
312        (new_state.into(), Ok(assoc_result))
313    }
314
315    fn send(self, _reader: impl Read, _len: usize) -> (TcpStateEnum<X>, Result<usize, SendError>) {
316        (self.into(), Err(SendError::NotConnected))
317    }
318
319    fn recv(self, _writer: impl Write, _len: usize) -> (TcpStateEnum<X>, Result<usize, RecvError>) {
320        (self.into(), Err(RecvError::NotConnected))
321    }
322
323    fn clear_error(&mut self) -> Option<TcpError> {
324        self.common.error.take()
325    }
326
327    fn poll(&self) -> PollState {
328        let mut poll_state = PollState::empty();
329
330        if self.common.error.is_some() {
331            poll_state.insert(PollState::ERROR);
332        }
333
334        poll_state
335    }
336
337    fn wants_to_send(&self) -> bool {
338        false
339    }
340
341    fn local_remote_addrs(&self) -> Option<(SocketAddrV4, SocketAddrV4)> {
342        None
343    }
344}
345
346impl<X: Dependencies> ListenState<X> {
347    fn new(common: Common<X>, config: TcpConfig, max_backlog: u32) -> Self {
348        ListenState {
349            common,
350            config,
351            max_backlog,
352            send_buffer: LinkedList::new(),
353            children: slotmap::DenseSlotMap::with_key(),
354            conn_map: HashMap::new(),
355            accept_queue: LinkedList::new(),
356            to_send: LinkedList::new(),
357        }
358    }
359
360    /// Register a new child TCP state for a new incoming connection.
361    fn register_child(&mut self, header: &TcpHeader, payload: Payload) -> ChildTcpKey {
362        let conn_addrs = RemoteLocalPair::new(header.src(), header.dst());
363
364        let key = self.children.insert_with_key(|key| {
365            let common = Common {
366                deps: self.common.deps.fork(),
367                child_key: Some(key),
368                error: None,
369            };
370
371            assert!(header.flags.contains(TcpFlags::SYN));
372            assert!(!header.flags.contains(TcpFlags::RST));
373
374            let mut connection =
375                Connection::new(header.dst(), header.src(), Seq::new(0), self.config);
376            connection.push_packet(header, payload).unwrap();
377
378            let new_tcp = SynReceivedState::new(common, connection);
379
380            ChildEntry {
381                state: Some(new_tcp.into()),
382                conn_addrs,
383            }
384        });
385
386        assert!(self.conn_map.insert(conn_addrs, key).is_none());
387
388        // make sure the child is added to all of the correct lists
389        self.sync_child(key).unwrap();
390
391        key
392    }
393
394    /// Make sure the parent's state is synchronized with the child's state. For example if the
395    /// child is in the "established" state, it should be in the parent's accept queue.
396    fn sync_child(&mut self, key: ChildTcpKey) -> Result<(), ChildNotFound> {
397        let is_closed;
398
399        {
400            let entry = self.children.get_mut(key).ok_or(ChildNotFound)?;
401            let child = &mut entry.state;
402            let conn_addrs = &entry.conn_addrs;
403
404            // add to or remove from the `to_send` list
405            if child.as_ref().unwrap().wants_to_send() {
406                // if it wants to send a packet but is not in the `to_send` list
407                if !self.to_send.contains(&key) {
408                    // add to the `to_send` list
409                    self.to_send.push_back(key);
410                }
411            } else {
412                // doesn't want to send a packet, remove from the `to_send` list
413                remove_from_list(&mut self.to_send, &key);
414            }
415
416            // add to or remove from the accept queue
417            if matches!(
418                child.as_ref().unwrap(),
419                TcpStateEnum::Established(_) | TcpStateEnum::CloseWait(_)
420            ) {
421                // if in the "established" or "close-wait" state, but not in the accept queue
422                if !self.accept_queue.contains(&key) {
423                    // add to the accept queue
424                    self.accept_queue.push_back(key);
425                }
426            } else {
427                // not in the "established" or "close-wait" state; remove from the accept queue
428                remove_from_list(&mut self.accept_queue, &key);
429            }
430
431            // make sure that it's contained in the src map
432            assert!(self.conn_map.contains_key(conn_addrs));
433            debug_assert_eq!(self.conn_map.get(conn_addrs).unwrap(), &key);
434
435            is_closed = child.as_ref().unwrap().poll().contains(PollState::CLOSED);
436        }
437
438        // if the child is closed, we can drop it
439        if is_closed {
440            self.remove_child(key).unwrap();
441        }
442
443        Ok(())
444    }
445
446    /// Remove a child state and all references to it (except timers). Returns `None` if there was
447    /// no child with the given key.
448    fn remove_child(&mut self, key: ChildTcpKey) -> Option<TcpStateEnum<X>> {
449        let entry = self.children.remove(key)?;
450        let child = entry.state.unwrap();
451        let conn_addrs = entry.conn_addrs;
452
453        // remove the child from any other lists/maps
454
455        remove_from_list(&mut self.accept_queue, &key);
456        remove_from_list(&mut self.to_send, &key);
457        assert_eq!(self.conn_map.remove(&conn_addrs), Some(key));
458
459        Some(child)
460    }
461
462    /// Get the child state.
463    fn child(&self, key: ChildTcpKey) -> Option<&TcpStateEnum<X>> {
464        self.children.get(key)?.state.as_ref()
465    }
466
467    /// Mutate the child's state, and automatically make sure that the parent's state is correctly
468    /// synced with the child's state (see [`sync_child`]).
469    fn with_child<T>(
470        &mut self,
471        key: ChildTcpKey,
472        f: impl FnOnce(TcpStateEnum<X>) -> (TcpStateEnum<X>, T),
473    ) -> Result<T, ChildNotFound> {
474        let rv;
475
476        {
477            let child = &mut self.children.get_mut(key).ok_or(ChildNotFound)?.state;
478
479            // run the closure
480            let mut state = child.take().unwrap();
481            (state, rv) = f(state);
482            *child = Some(state);
483        }
484
485        self.sync_child(key).unwrap();
486
487        Ok(rv)
488    }
489}
490
491impl<X: Dependencies> TcpStateTrait<X> for ListenState<X> {
492    fn close(self) -> (TcpStateEnum<X>, Result<(), CloseError>) {
493        let (new_state, rv) = self.rst_close();
494        assert!(rv.is_ok());
495        (new_state, Ok(()))
496    }
497
498    fn rst_close(mut self) -> (TcpStateEnum<X>, Result<(), RstCloseError>) {
499        let child_keys = Vec::from_iter(self.children.keys());
500
501        for key in child_keys {
502            self.with_child(key, |child| child.rst_close())
503                .unwrap()
504                .unwrap();
505
506            // get any packets that it wants to send and add them to our send buffer; removing a
507            // packet may cause the child to close which will make `key` invalid, which is why we
508            // don't unwrap here
509            while let Ok(Ok((header, payload))) = self.with_child(key, |child| child.pop_packet()) {
510                assert!(payload.is_empty());
511                self.send_buffer.push_back(header);
512            }
513        }
514
515        // The `rst_close` should have moved the child states to either "closed" or "rst" and
516        // possibly queued some RST packets. Then we should have taken those packets from the child
517        // state and moved them to our buffer, which would have then moved all child states to
518        // "closed". Finally `with_child` would have seen that they closed and removed them from
519        // `self.children`.
520        assert!(self.children.is_empty());
521
522        // get all rst packets from our send buffer
523        let rst_packets: LinkedList<_> = self
524            .send_buffer
525            .into_iter()
526            .filter(|header| header.flags.contains(TcpFlags::RST))
527            .collect();
528
529        let new_state = if rst_packets.is_empty() {
530            // no RST packets to send, so go directly to the "closed" state
531            ClosedState::new(self.common, None, /* was_connected= */ false).into()
532        } else {
533            RstState::new(self.common, rst_packets, /* was_connected= */ false).into()
534        };
535
536        (new_state, Ok(()))
537    }
538
539    fn shutdown(self, _how: Shutdown) -> (TcpStateEnum<X>, Result<(), ShutdownError>) {
540        // TODO: Linux will reset back to the initial state (allowing future connect(), listen(),
541        // etc for the same socket) for SHUT_RD or SHUT_RDWR. But this should probably be handled in
542        // a higher layer (for example having the socket replace this tcp state with a new tcp state
543        // object).
544
545        (self.into(), Err(ShutdownError::NotConnected))
546    }
547
548    fn listen<T, E>(
549        mut self,
550        backlog: u32,
551        associate_fn: impl FnOnce() -> Result<T, E>,
552    ) -> (TcpStateEnum<X>, Result<T, ListenError<E>>) {
553        // we don't need to associate, but we run this closure anyway; the caller can make this a
554        // no-op if it doesn't need to associate
555        let rv = match associate_fn() {
556            Ok(x) => x,
557            Err(e) => return (self.into(), Err(ListenError::FailedAssociation(e))),
558        };
559
560        // linux uses a limit of one greater than the provided backlog
561        let max_backlog = backlog.saturating_add(1);
562
563        // we're already listening, so must already be associated; just update the backlog
564        self.max_backlog = max_backlog;
565        (self.into(), Ok(rv))
566    }
567
568    fn connect<T, E>(
569        self,
570        _remote_addr: SocketAddrV4,
571        _associate_fn: impl FnOnce() -> Result<(SocketAddrV4, T), E>,
572    ) -> (TcpStateEnum<X>, Result<T, ConnectError<E>>) {
573        (self.into(), Err(ConnectError::IsListening))
574    }
575
576    fn accept(mut self) -> (TcpStateEnum<X>, Result<AcceptedTcpState<X>, AcceptError>) {
577        let Some(child_key) = self.accept_queue.pop_front() else {
578            return (self.into(), Err(AcceptError::NothingToAccept));
579        };
580
581        let child = self.remove_child(child_key).unwrap();
582
583        // if the child is in an acceptable state, it's wrapped in an `AcceptedTcpState` and
584        // returned to the caller
585        let accepted_state = match child.try_into() {
586            Ok(x) => x,
587            Err(child) => {
588                // the child is in a state that we can't return to the caller, so we messed up
589                // somewhere earlier
590                panic!("Unexpected child TCP state in accept queue: {:?}", child);
591            }
592        };
593
594        (self.into(), Ok(accepted_state))
595    }
596
597    fn send(self, _reader: impl Read, _len: usize) -> (TcpStateEnum<X>, Result<usize, SendError>) {
598        (self.into(), Err(SendError::NotConnected))
599    }
600
601    fn recv(self, _writer: impl Write, _len: usize) -> (TcpStateEnum<X>, Result<usize, RecvError>) {
602        (self.into(), Err(RecvError::NotConnected))
603    }
604
605    fn push_packet(
606        mut self,
607        header: &TcpHeader,
608        payload: Payload,
609    ) -> (TcpStateEnum<X>, Result<u32, PushPacketError>) {
610        // In Linux there is conceptually the syn queue and the accept queue. When the application
611        // calls `listen()`, it passes a `backlog` argument. The question is: does this backlog
612        // apply to the syn queue, accept queue, or both? Some references[1] and the listen(2)[2]
613        // man page say that the backlog only applies to the accept queue, but other blogs[3,4] and
614        // stack overflow[5] say that it applies to both queues.
615        //
616        // The truth is probably more nuanced, and Linux technically doesn't have a "syn queue", but
617        // this should be good enough for us. We'll apply the backlog as a limit to both queues
618        // (each queue can hold `backlog` entries). In our case, the "syn queue" length is
619        // `children.len() - accept_queue.len()`.
620        //
621        // If the accept queue is full, the application is slow at accept()ing new connections. As a
622        // push-back mechanism drop all incoming SYN packets, and incoming ACK packets that are
623        // intended for a child in the "syn-received" state (since they would then get added to the
624        // accept queue, but the accept queue is full). If the syn queue is full, drop all incoming
625        // SYN packets (we don't support SYN cookies). This seems to be along the lines of what
626        // Linux does.[4]
627        //
628        // [1]: https://veithen.io/2014/01/01/how-tcp-backlog-works-in-linux.html
629        // [2]: https://man7.org/linux/man-pages/man2/listen.2.html
630        // [3]: https://arthurchiao.art/blog/tcp-listen-a-tale-of-two-queues/
631        // [4]: https://blog.cloudflare.com/syn-packet-handling-in-the-wild/
632        // [5]: https://stackoverflow.com/questions/58183847/
633
634        let max_backlog = self.max_backlog.try_into().unwrap();
635        let syn_queue_len = self
636            .children
637            .len()
638            .checked_sub(self.accept_queue.len())
639            .unwrap();
640        let accept_queue_full = self.accept_queue.len() >= max_backlog;
641        let syn_queue_full = syn_queue_len >= max_backlog;
642
643        // if either queue is full, drop all SYN packets
644        if header.flags.contains(TcpFlags::SYN) && (accept_queue_full || syn_queue_full) {
645            return (self.into(), Ok(0));
646        }
647
648        let conn_addrs = RemoteLocalPair::new(header.src(), header.dst());
649
650        // forward the packet to a child state if it's from a known src address
651        if let Some(child_key) = self.conn_map.get(&conn_addrs) {
652            // if in the "syn-received" state, is an ACK packet, and the accept queue is full, drop
653            // the packet
654            if matches!(self.child(*child_key), Some(TcpStateEnum::SynReceived(_)))
655                && header.flags.contains(TcpFlags::ACK)
656                && accept_queue_full
657            {
658                return (self.into(), Ok(0));
659            }
660
661            // forward the packet to the child state
662            let rv = self
663                .with_child(*child_key, |state| state.push_packet(header, payload))
664                .unwrap();
665
666            // propagate any error from the child to the caller
667            return (self.into(), rv);
668        }
669
670        // this packet is meant for the listener, or for a child that no longer exists
671
672        // drop non-SYN packets
673        if !header.flags.contains(TcpFlags::SYN) {
674            // it's either for an old child that no longer exists, or is for the listener and
675            // doesn't have the SYN flag for some reason
676            return (self.into(), Ok(0));
677        }
678
679        // we received a SYN packet, so register a new child in the "syn-received" state
680        self.register_child(header, payload);
681
682        (self.into(), Ok(0))
683    }
684
685    fn pop_packet(
686        mut self,
687    ) -> (
688        TcpStateEnum<X>,
689        Result<(TcpHeader, Payload), PopPacketError>,
690    ) {
691        if let Some(header) = self.send_buffer.pop_front() {
692            return (self.into(), Ok((header, Payload::default())));
693        }
694
695        if let Some(child_key) = self.to_send.pop_front() {
696            let rv = self
697                .with_child(child_key, |state| state.pop_packet())
698                .unwrap();
699
700            // if the child was in the list, then we'll assume it must have a packet to send
701            let (header, payload) = rv.unwrap();
702
703            // might as well check this
704            debug_assert!(payload.is_empty());
705
706            return (self.into(), Ok((header, payload)));
707        }
708
709        (self.into(), Err(PopPacketError::NoPacket))
710    }
711
712    fn clear_error(&mut self) -> Option<TcpError> {
713        self.common.error.take()
714    }
715
716    fn poll(&self) -> PollState {
717        let mut poll_state = PollState::LISTENING;
718
719        if !self.accept_queue.is_empty() {
720            poll_state.insert(PollState::READY_TO_ACCEPT);
721        }
722
723        if self.common.error.is_some() {
724            poll_state.insert(PollState::ERROR);
725        }
726
727        poll_state
728    }
729
730    fn wants_to_send(&self) -> bool {
731        !self.send_buffer.is_empty() || !self.to_send.is_empty()
732    }
733
734    fn local_remote_addrs(&self) -> Option<(SocketAddrV4, SocketAddrV4)> {
735        None
736    }
737}
738
739impl<X: Dependencies> SynSentState<X> {
740    fn new(common: Common<X>, connection: Connection<X::Instant>) -> Self {
741        let state = SynSentState { common, connection };
742
743        // if still in the "syn-sent" state after 60 seconds, close it
744        let timeout = state.common.current_time() + X::Duration::from_secs(60);
745        state.common.register_timer(timeout, |state| {
746            if let TcpStateEnum::SynSent(mut state) = state {
747                state.common.error = Some(TcpError::TimedOut);
748
749                let (state, rv) = state.rst_close();
750                assert!(rv.is_ok());
751                state
752            } else {
753                state
754            }
755        });
756
757        state
758    }
759}
760
761impl<X: Dependencies> TcpStateTrait<X> for SynSentState<X> {
762    fn close(mut self) -> (TcpStateEnum<X>, Result<(), CloseError>) {
763        // we haven't received a SYN yet, so we can't have received data and don't
764        // need to send an RST
765        debug_assert!(!self.connection.recv_buf_has_data());
766
767        self.common
768            .set_error_if_unset(TcpError::ClosedWhileConnecting);
769
770        let new_state = ClosedState::new(self.common, None, /* was_connected= */ true);
771        (new_state.into(), Ok(()))
772    }
773
774    fn rst_close(mut self) -> (TcpStateEnum<X>, Result<(), RstCloseError>) {
775        // we haven't received a SYN yet, so we can't have received data and don't
776        // need to send an RST
777        debug_assert!(!self.connection.recv_buf_has_data());
778
779        self.common
780            .set_error_if_unset(TcpError::ClosedWhileConnecting);
781
782        let new_state = ClosedState::new(self.common, None, /* was_connected= */ true);
783        (new_state.into(), Ok(()))
784    }
785
786    fn shutdown(mut self, how: Shutdown) -> (TcpStateEnum<X>, Result<(), ShutdownError>) {
787        if how == Shutdown::Read || how == Shutdown::Both {
788            self.connection.send_rst_if_recv_payload()
789        }
790
791        if how == Shutdown::Write || how == Shutdown::Both {
792            // we haven't received a SYN yet, so we can't have received data and don't
793            // need to send an RST
794            debug_assert!(!self.connection.recv_buf_has_data());
795
796            self.common
797                .set_error_if_unset(TcpError::ClosedWhileConnecting);
798
799            let new_state = ClosedState::new(self.common, None, /* was_connected= */ true);
800            return (new_state.into(), Ok(()));
801        }
802
803        (self.into(), Ok(()))
804    }
805
806    fn connect<T, E>(
807        self,
808        _remote_addr: SocketAddrV4,
809        _associate_fn: impl FnOnce() -> Result<(SocketAddrV4, T), E>,
810    ) -> (TcpStateEnum<X>, Result<T, ConnectError<E>>) {
811        (self.into(), Err(ConnectError::InProgress))
812    }
813
814    fn send(self, _reader: impl Read, _len: usize) -> (TcpStateEnum<X>, Result<usize, SendError>) {
815        (self.into(), Err(SendError::NotConnected))
816    }
817
818    fn recv(self, _writer: impl Write, _len: usize) -> (TcpStateEnum<X>, Result<usize, RecvError>) {
819        (self.into(), Err(RecvError::NotConnected))
820    }
821
822    fn push_packet(
823        mut self,
824        header: &TcpHeader,
825        payload: Payload,
826    ) -> (TcpStateEnum<X>, Result<u32, PushPacketError>) {
827        // make sure that the packet src/dst addresses are valid for this connection
828        if !self.connection.packet_addrs_match(header) {
829            // must drop the packet
830            return (self.into(), Ok(0));
831        }
832
833        let pushed_len = match self.connection.push_packet(header, payload) {
834            Ok(v) => v,
835            Err(e) => return (self.into(), Err(e)),
836        };
837
838        // if the connection was reset
839        if self.connection.is_reset() {
840            if header.flags.contains(TcpFlags::RST) {
841                self.common.set_error_if_unset(TcpError::ResetReceived);
842            }
843
844            let new_state = connection_was_reset(self.common, self.connection);
845            return (new_state, Ok(pushed_len));
846        }
847
848        // if received SYN and ACK (active open), move to the "established" state
849        if self.connection.received_syn() && self.connection.syn_was_acked() {
850            let new_state = EstablishedState::new(self.common, self.connection);
851            return (new_state.into(), Ok(pushed_len));
852        }
853
854        // if received SYN and no ACK (simultaneous open), move to the "syn-received" state
855        if self.connection.received_syn() {
856            let new_state = SynReceivedState::new(self.common, self.connection);
857            return (new_state.into(), Ok(pushed_len));
858        }
859
860        // TODO: unsure what to do otherwise; just dropping the packet
861
862        (self.into(), Ok(pushed_len))
863    }
864
865    fn pop_packet(
866        mut self,
867    ) -> (
868        TcpStateEnum<X>,
869        Result<(TcpHeader, Payload), PopPacketError>,
870    ) {
871        let rv = self.connection.pop_packet(self.common.current_time());
872        (self.into(), rv)
873    }
874
875    fn clear_error(&mut self) -> Option<TcpError> {
876        self.common.error.take()
877    }
878
879    fn poll(&self) -> PollState {
880        let mut poll_state = PollState::CONNECTING;
881
882        if self.common.error.is_some() {
883            poll_state.insert(PollState::ERROR);
884        }
885
886        poll_state
887    }
888
889    fn wants_to_send(&self) -> bool {
890        self.connection.wants_to_send()
891    }
892
893    fn local_remote_addrs(&self) -> Option<(SocketAddrV4, SocketAddrV4)> {
894        Some((self.connection.local_addr, self.connection.remote_addr))
895    }
896}
897
898impl<X: Dependencies> SynReceivedState<X> {
899    fn new(common: Common<X>, connection: Connection<X::Instant>) -> Self {
900        let state = SynReceivedState { common, connection };
901
902        // if still in the "syn-received" state after 60 seconds, close it with a RST
903        let timeout = state.common.current_time() + X::Duration::from_secs(60);
904        state.common.register_timer(timeout, |state| {
905            if let TcpStateEnum::SynReceived(mut state) = state {
906                state.common.error = Some(TcpError::TimedOut);
907
908                let (state, rv) = state.rst_close();
909                assert!(rv.is_ok());
910                return state;
911            }
912
913            state
914        });
915
916        state
917    }
918}
919
920impl<X: Dependencies> TcpStateTrait<X> for SynReceivedState<X> {
921    fn close(mut self) -> (TcpStateEnum<X>, Result<(), CloseError>) {
922        let new_state = if self.connection.recv_buf_has_data() {
923            // send a RST if there is still data in the receive buffer
924            reset_connection(self.common, self.connection).into()
925        } else {
926            // send a FIN packet
927            self.connection.send_fin();
928
929            self.common
930                .set_error_if_unset(TcpError::ClosedWhileConnecting);
931
932            // if the connection receives any more data, it should send an RST
933            self.connection.send_rst_if_recv_payload();
934
935            FinWaitOneState::new(self.common, self.connection).into()
936        };
937
938        (new_state, Ok(()))
939    }
940
941    fn rst_close(self) -> (TcpStateEnum<X>, Result<(), RstCloseError>) {
942        let new_state = reset_connection(self.common, self.connection);
943        (new_state.into(), Ok(()))
944    }
945
946    fn shutdown(mut self, how: Shutdown) -> (TcpStateEnum<X>, Result<(), ShutdownError>) {
947        if how == Shutdown::Read || how == Shutdown::Both {
948            self.connection.send_rst_if_recv_payload()
949        }
950
951        if how == Shutdown::Write || how == Shutdown::Both {
952            // send a FIN packet
953            self.connection.send_fin();
954
955            self.common
956                .set_error_if_unset(TcpError::ClosedWhileConnecting);
957
958            let new_state = FinWaitOneState::new(self.common, self.connection);
959            return (new_state.into(), Ok(()));
960        }
961
962        (self.into(), Ok(()))
963    }
964
965    fn connect<T, E>(
966        self,
967        _remote_addr: SocketAddrV4,
968        _associate_fn: impl FnOnce() -> Result<(SocketAddrV4, T), E>,
969    ) -> (TcpStateEnum<X>, Result<T, ConnectError<E>>) {
970        (self.into(), Err(ConnectError::InProgress))
971    }
972
973    fn send(self, _reader: impl Read, _len: usize) -> (TcpStateEnum<X>, Result<usize, SendError>) {
974        (self.into(), Err(SendError::NotConnected))
975    }
976
977    fn recv(self, _writer: impl Write, _len: usize) -> (TcpStateEnum<X>, Result<usize, RecvError>) {
978        (self.into(), Err(RecvError::NotConnected))
979    }
980
981    fn push_packet(
982        mut self,
983        header: &TcpHeader,
984        payload: Payload,
985    ) -> (TcpStateEnum<X>, Result<u32, PushPacketError>) {
986        // waiting for the ACK for our SYN
987
988        // make sure that the packet src/dst addresses are valid for this connection
989        if !self.connection.packet_addrs_match(header) {
990            // must drop the packet
991            return (self.into(), Ok(0));
992        }
993
994        let pushed_len = match self.connection.push_packet(header, payload) {
995            Ok(v) => v,
996            Err(e) => return (self.into(), Err(e)),
997        };
998
999        // if the connection was reset
1000        if self.connection.is_reset() {
1001            if header.flags.contains(TcpFlags::RST) {
1002                self.common.set_error_if_unset(TcpError::ResetReceived);
1003            }
1004
1005            let new_state = connection_was_reset(self.common, self.connection);
1006            return (new_state, Ok(pushed_len));
1007        }
1008
1009        // if received ACK, move to the "established" state
1010        if self.connection.syn_was_acked() {
1011            let new_state = EstablishedState::new(self.common, self.connection);
1012            return (new_state.into(), Ok(pushed_len));
1013        }
1014
1015        (self.into(), Ok(pushed_len))
1016    }
1017
1018    fn pop_packet(
1019        mut self,
1020    ) -> (
1021        TcpStateEnum<X>,
1022        Result<(TcpHeader, Payload), PopPacketError>,
1023    ) {
1024        let rv = self.connection.pop_packet(self.common.current_time());
1025        (self.into(), rv)
1026    }
1027
1028    fn clear_error(&mut self) -> Option<TcpError> {
1029        self.common.error.take()
1030    }
1031
1032    fn poll(&self) -> PollState {
1033        let mut poll_state = PollState::CONNECTING;
1034
1035        if self.common.error.is_some() {
1036            poll_state.insert(PollState::ERROR);
1037        }
1038
1039        poll_state
1040    }
1041
1042    fn wants_to_send(&self) -> bool {
1043        self.connection.wants_to_send()
1044    }
1045
1046    fn local_remote_addrs(&self) -> Option<(SocketAddrV4, SocketAddrV4)> {
1047        Some((self.connection.local_addr, self.connection.remote_addr))
1048    }
1049}
1050
1051impl<X: Dependencies> EstablishedState<X> {
1052    fn new(common: Common<X>, connection: Connection<X::Instant>) -> Self {
1053        EstablishedState { common, connection }
1054    }
1055}
1056
1057impl<X: Dependencies> TcpStateTrait<X> for EstablishedState<X> {
1058    fn close(mut self) -> (TcpStateEnum<X>, Result<(), CloseError>) {
1059        let new_state = if self.connection.recv_buf_has_data() {
1060            // send a RST if there is still data in the receive buffer
1061            reset_connection(self.common, self.connection).into()
1062        } else {
1063            // send a FIN packet
1064            self.connection.send_fin();
1065
1066            // if the connection receives any more data, it should send an RST
1067            self.connection.send_rst_if_recv_payload();
1068
1069            FinWaitOneState::new(self.common, self.connection).into()
1070        };
1071
1072        (new_state, Ok(()))
1073    }
1074
1075    fn rst_close(self) -> (TcpStateEnum<X>, Result<(), RstCloseError>) {
1076        let new_state = reset_connection(self.common, self.connection);
1077        (new_state.into(), Ok(()))
1078    }
1079
1080    fn shutdown(mut self, how: Shutdown) -> (TcpStateEnum<X>, Result<(), ShutdownError>) {
1081        if how == Shutdown::Read || how == Shutdown::Both {
1082            self.connection.send_rst_if_recv_payload()
1083        }
1084
1085        if how == Shutdown::Write || how == Shutdown::Both {
1086            // send a FIN packet
1087            self.connection.send_fin();
1088
1089            let new_state = FinWaitOneState::new(self.common, self.connection);
1090            return (new_state.into(), Ok(()));
1091        }
1092
1093        (self.into(), Ok(()))
1094    }
1095
1096    fn connect<T, E>(
1097        self,
1098        _remote_addr: SocketAddrV4,
1099        _associate_fn: impl FnOnce() -> Result<(SocketAddrV4, T), E>,
1100    ) -> (TcpStateEnum<X>, Result<T, ConnectError<E>>) {
1101        (self.into(), Err(ConnectError::AlreadyConnected))
1102    }
1103
1104    fn send(
1105        mut self,
1106        reader: impl Read,
1107        len: usize,
1108    ) -> (TcpStateEnum<X>, Result<usize, SendError>) {
1109        let rv = self.connection.send(reader, len);
1110        (self.into(), rv)
1111    }
1112
1113    fn recv(
1114        mut self,
1115        writer: impl Write,
1116        len: usize,
1117    ) -> (TcpStateEnum<X>, Result<usize, RecvError>) {
1118        let rv = self.connection.recv(writer, len);
1119        (self.into(), rv)
1120    }
1121
1122    fn push_packet(
1123        mut self,
1124        header: &TcpHeader,
1125        payload: Payload,
1126    ) -> (TcpStateEnum<X>, Result<u32, PushPacketError>) {
1127        // make sure that the packet src/dst addresses are valid for this connection
1128        if !self.connection.packet_addrs_match(header) {
1129            // must drop the packet
1130            return (self.into(), Ok(0));
1131        }
1132
1133        let pushed_len = match self.connection.push_packet(header, payload) {
1134            Ok(v) => v,
1135            Err(e) => return (self.into(), Err(e)),
1136        };
1137
1138        // if the connection was reset
1139        if self.connection.is_reset() {
1140            if header.flags.contains(TcpFlags::RST) {
1141                self.common.set_error_if_unset(TcpError::ResetReceived);
1142            }
1143
1144            let new_state = connection_was_reset(self.common, self.connection);
1145            return (new_state, Ok(pushed_len));
1146        }
1147
1148        // if received FIN, move to the "close-wait" state
1149        if self.connection.received_fin() {
1150            let new_state = CloseWaitState::new(self.common, self.connection);
1151            return (new_state.into(), Ok(pushed_len));
1152        }
1153
1154        (self.into(), Ok(pushed_len))
1155    }
1156
1157    fn pop_packet(
1158        mut self,
1159    ) -> (
1160        TcpStateEnum<X>,
1161        Result<(TcpHeader, Payload), PopPacketError>,
1162    ) {
1163        let rv = self.connection.pop_packet(self.common.current_time());
1164        (self.into(), rv)
1165    }
1166
1167    fn clear_error(&mut self) -> Option<TcpError> {
1168        self.common.error.take()
1169    }
1170
1171    fn poll(&self) -> PollState {
1172        let mut poll_state = PollState::CONNECTED;
1173
1174        if self.connection.send_buf_has_space() {
1175            poll_state.insert(PollState::WRITABLE);
1176        }
1177
1178        if self.connection.recv_buf_has_data() {
1179            poll_state.insert(PollState::READABLE);
1180        }
1181
1182        if self.common.error.is_some() {
1183            poll_state.insert(PollState::ERROR);
1184        }
1185
1186        poll_state
1187    }
1188
1189    fn wants_to_send(&self) -> bool {
1190        self.connection.wants_to_send()
1191    }
1192
1193    fn local_remote_addrs(&self) -> Option<(SocketAddrV4, SocketAddrV4)> {
1194        Some((self.connection.local_addr, self.connection.remote_addr))
1195    }
1196}
1197
1198impl<X: Dependencies> FinWaitOneState<X> {
1199    fn new(common: Common<X>, connection: Connection<X::Instant>) -> Self {
1200        FinWaitOneState { common, connection }
1201    }
1202}
1203
1204impl<X: Dependencies> TcpStateTrait<X> for FinWaitOneState<X> {
1205    fn close(mut self) -> (TcpStateEnum<X>, Result<(), CloseError>) {
1206        let new_state = if self.connection.recv_buf_has_data() {
1207            // send a RST if there is still data in the receive buffer
1208            reset_connection(self.common, self.connection).into()
1209        } else {
1210            // if the connection receives any more data, it should send an RST
1211            self.connection.send_rst_if_recv_payload();
1212
1213            // we're already in the process of closing (active close)
1214            self.into()
1215        };
1216
1217        (new_state, Ok(()))
1218    }
1219
1220    fn shutdown(mut self, how: Shutdown) -> (TcpStateEnum<X>, Result<(), ShutdownError>) {
1221        if how == Shutdown::Read || how == Shutdown::Both {
1222            self.connection.send_rst_if_recv_payload()
1223        }
1224
1225        if how == Shutdown::Write || how == Shutdown::Both {
1226            // we're already in the process of closing (active close)
1227        }
1228
1229        (self.into(), Ok(()))
1230    }
1231
1232    fn connect<T, E>(
1233        self,
1234        _remote_addr: SocketAddrV4,
1235        _associate_fn: impl FnOnce() -> Result<(SocketAddrV4, T), E>,
1236    ) -> (TcpStateEnum<X>, Result<T, ConnectError<E>>) {
1237        (self.into(), Err(ConnectError::AlreadyConnected))
1238    }
1239
1240    fn send(self, _reader: impl Read, _len: usize) -> (TcpStateEnum<X>, Result<usize, SendError>) {
1241        (self.into(), Err(SendError::StreamClosed))
1242    }
1243
1244    fn recv(
1245        mut self,
1246        writer: impl Write,
1247        len: usize,
1248    ) -> (TcpStateEnum<X>, Result<usize, RecvError>) {
1249        let rv = self.connection.recv(writer, len);
1250        (self.into(), rv)
1251    }
1252
1253    fn push_packet(
1254        mut self,
1255        header: &TcpHeader,
1256        payload: Payload,
1257    ) -> (TcpStateEnum<X>, Result<u32, PushPacketError>) {
1258        // make sure that the packet src/dst addresses are valid for this connection
1259        if !self.connection.packet_addrs_match(header) {
1260            // must drop the packet
1261            return (self.into(), Ok(0));
1262        }
1263
1264        let pushed_len = match self.connection.push_packet(header, payload) {
1265            Ok(v) => v,
1266            Err(e) => return (self.into(), Err(e)),
1267        };
1268
1269        // if the connection was reset
1270        if self.connection.is_reset() {
1271            if header.flags.contains(TcpFlags::RST) {
1272                self.common.set_error_if_unset(TcpError::ResetReceived);
1273            }
1274
1275            let new_state = connection_was_reset(self.common, self.connection);
1276            return (new_state, Ok(pushed_len));
1277        }
1278
1279        // if received FIN and ACK, move to the "time-wait" state
1280        if self.connection.received_fin() && self.connection.fin_was_acked() {
1281            let new_state = TimeWaitState::new(self.common, self.connection);
1282            return (new_state.into(), Ok(pushed_len));
1283        }
1284
1285        // if received FIN, move to the "closing" state
1286        if self.connection.received_fin() {
1287            let new_state = ClosingState::new(self.common, self.connection);
1288            return (new_state.into(), Ok(pushed_len));
1289        }
1290
1291        // if received ACK, move to the "fin-wait-two" state
1292        if self.connection.fin_was_acked() {
1293            let new_state = FinWaitTwoState::new(self.common, self.connection);
1294            return (new_state.into(), Ok(pushed_len));
1295        }
1296
1297        (self.into(), Ok(pushed_len))
1298    }
1299
1300    fn pop_packet(
1301        mut self,
1302    ) -> (
1303        TcpStateEnum<X>,
1304        Result<(TcpHeader, Payload), PopPacketError>,
1305    ) {
1306        let rv = self.connection.pop_packet(self.common.current_time());
1307        (self.into(), rv)
1308    }
1309
1310    fn clear_error(&mut self) -> Option<TcpError> {
1311        self.common.error.take()
1312    }
1313
1314    fn poll(&self) -> PollState {
1315        let mut poll_state = PollState::CONNECTED;
1316
1317        if self.connection.recv_buf_has_data() {
1318            poll_state.insert(PollState::READABLE);
1319        }
1320
1321        // we've sent a FIN
1322        poll_state.insert(PollState::SEND_CLOSED);
1323        assert!(!poll_state.contains(PollState::WRITABLE));
1324
1325        if self.common.error.is_some() {
1326            poll_state.insert(PollState::ERROR);
1327        }
1328
1329        poll_state
1330    }
1331
1332    fn wants_to_send(&self) -> bool {
1333        self.connection.wants_to_send()
1334    }
1335
1336    fn local_remote_addrs(&self) -> Option<(SocketAddrV4, SocketAddrV4)> {
1337        Some((self.connection.local_addr, self.connection.remote_addr))
1338    }
1339}
1340
1341impl<X: Dependencies> FinWaitTwoState<X> {
1342    fn new(common: Common<X>, connection: Connection<X::Instant>) -> Self {
1343        FinWaitTwoState { common, connection }
1344    }
1345}
1346
1347impl<X: Dependencies> TcpStateTrait<X> for FinWaitTwoState<X> {
1348    fn close(mut self) -> (TcpStateEnum<X>, Result<(), CloseError>) {
1349        let new_state = if self.connection.recv_buf_has_data() {
1350            // send a RST if there is still data in the receive buffer
1351            reset_connection(self.common, self.connection).into()
1352        } else {
1353            // if the connection receives any more data, it should send an RST
1354            self.connection.send_rst_if_recv_payload();
1355
1356            // we're already in the process of closing (active close)
1357            self.into()
1358        };
1359
1360        (new_state, Ok(()))
1361    }
1362
1363    fn shutdown(mut self, how: Shutdown) -> (TcpStateEnum<X>, Result<(), ShutdownError>) {
1364        if how == Shutdown::Read || how == Shutdown::Both {
1365            self.connection.send_rst_if_recv_payload()
1366        }
1367
1368        if how == Shutdown::Write || how == Shutdown::Both {
1369            // we're already in the process of closing (active close)
1370        }
1371
1372        (self.into(), Ok(()))
1373    }
1374
1375    fn connect<T, E>(
1376        self,
1377        _remote_addr: SocketAddrV4,
1378        _associate_fn: impl FnOnce() -> Result<(SocketAddrV4, T), E>,
1379    ) -> (TcpStateEnum<X>, Result<T, ConnectError<E>>) {
1380        (self.into(), Err(ConnectError::AlreadyConnected))
1381    }
1382
1383    fn send(self, _reader: impl Read, _len: usize) -> (TcpStateEnum<X>, Result<usize, SendError>) {
1384        (self.into(), Err(SendError::StreamClosed))
1385    }
1386
1387    fn recv(
1388        mut self,
1389        writer: impl Write,
1390        len: usize,
1391    ) -> (TcpStateEnum<X>, Result<usize, RecvError>) {
1392        let rv = self.connection.recv(writer, len);
1393        (self.into(), rv)
1394    }
1395
1396    fn push_packet(
1397        mut self,
1398        header: &TcpHeader,
1399        payload: Payload,
1400    ) -> (TcpStateEnum<X>, Result<u32, PushPacketError>) {
1401        // make sure that the packet src/dst addresses are valid for this connection
1402        if !self.connection.packet_addrs_match(header) {
1403            // must drop the packet
1404            return (self.into(), Ok(0));
1405        }
1406
1407        let pushed_len = match self.connection.push_packet(header, payload) {
1408            Ok(v) => v,
1409            Err(e) => return (self.into(), Err(e)),
1410        };
1411
1412        // if the connection was reset
1413        if self.connection.is_reset() {
1414            if header.flags.contains(TcpFlags::RST) {
1415                self.common.set_error_if_unset(TcpError::ResetReceived);
1416            }
1417
1418            let new_state = connection_was_reset(self.common, self.connection);
1419            return (new_state, Ok(pushed_len));
1420        }
1421
1422        // if received FIN, move to the "time-wait" state
1423        if self.connection.received_fin() {
1424            let new_state = TimeWaitState::new(self.common, self.connection);
1425            return (new_state.into(), Ok(pushed_len));
1426        }
1427
1428        (self.into(), Ok(pushed_len))
1429    }
1430
1431    fn pop_packet(
1432        mut self,
1433    ) -> (
1434        TcpStateEnum<X>,
1435        Result<(TcpHeader, Payload), PopPacketError>,
1436    ) {
1437        let rv = self.connection.pop_packet(self.common.current_time());
1438        (self.into(), rv)
1439    }
1440
1441    fn clear_error(&mut self) -> Option<TcpError> {
1442        self.common.error.take()
1443    }
1444
1445    fn poll(&self) -> PollState {
1446        let mut poll_state = PollState::CONNECTED;
1447
1448        if self.connection.recv_buf_has_data() {
1449            poll_state.insert(PollState::READABLE);
1450        }
1451
1452        // we've sent a FIN
1453        poll_state.insert(PollState::SEND_CLOSED);
1454        assert!(!poll_state.contains(PollState::WRITABLE));
1455
1456        if self.common.error.is_some() {
1457            poll_state.insert(PollState::ERROR);
1458        }
1459
1460        poll_state
1461    }
1462
1463    fn wants_to_send(&self) -> bool {
1464        self.connection.wants_to_send()
1465    }
1466
1467    fn local_remote_addrs(&self) -> Option<(SocketAddrV4, SocketAddrV4)> {
1468        Some((self.connection.local_addr, self.connection.remote_addr))
1469    }
1470}
1471
1472impl<X: Dependencies> ClosingState<X> {
1473    fn new(common: Common<X>, connection: Connection<X::Instant>) -> Self {
1474        ClosingState { common, connection }
1475    }
1476}
1477
1478impl<X: Dependencies> TcpStateTrait<X> for ClosingState<X> {
1479    fn close(mut self) -> (TcpStateEnum<X>, Result<(), CloseError>) {
1480        let new_state = if self.connection.recv_buf_has_data() {
1481            // send a RST if there is still data in the receive buffer
1482            reset_connection(self.common, self.connection).into()
1483        } else {
1484            // if the connection receives any more data, it should send an RST
1485            self.connection.send_rst_if_recv_payload();
1486
1487            // we're already in the process of closing (active close)
1488            self.into()
1489        };
1490
1491        (new_state, Ok(()))
1492    }
1493
1494    fn shutdown(mut self, how: Shutdown) -> (TcpStateEnum<X>, Result<(), ShutdownError>) {
1495        if how == Shutdown::Read || how == Shutdown::Both {
1496            self.connection.send_rst_if_recv_payload()
1497        }
1498
1499        if how == Shutdown::Write || how == Shutdown::Both {
1500            // we're already in the process of closing (active close)
1501        }
1502
1503        (self.into(), Ok(()))
1504    }
1505
1506    fn connect<T, E>(
1507        self,
1508        _remote_addr: SocketAddrV4,
1509        _associate_fn: impl FnOnce() -> Result<(SocketAddrV4, T), E>,
1510    ) -> (TcpStateEnum<X>, Result<T, ConnectError<E>>) {
1511        (self.into(), Err(ConnectError::AlreadyConnected))
1512    }
1513
1514    fn send(self, _reader: impl Read, _len: usize) -> (TcpStateEnum<X>, Result<usize, SendError>) {
1515        (self.into(), Err(SendError::StreamClosed))
1516    }
1517
1518    fn recv(
1519        mut self,
1520        writer: impl Write,
1521        len: usize,
1522    ) -> (TcpStateEnum<X>, Result<usize, RecvError>) {
1523        let rv = self.connection.recv(writer, len);
1524
1525        // the peer won't send any more data (it sent a FIN), so if there's no more data in the
1526        // buffer, inform the socket
1527        if matches!(rv, Err(RecvError::Empty)) {
1528            return (self.into(), Err(RecvError::StreamClosed));
1529        }
1530
1531        (self.into(), rv)
1532    }
1533
1534    fn push_packet(
1535        mut self,
1536        header: &TcpHeader,
1537        payload: Payload,
1538    ) -> (TcpStateEnum<X>, Result<u32, PushPacketError>) {
1539        // make sure that the packet src/dst addresses are valid for this connection
1540        if !self.connection.packet_addrs_match(header) {
1541            // must drop the packet
1542            return (self.into(), Ok(0));
1543        }
1544
1545        let pushed_len = match self.connection.push_packet(header, payload) {
1546            Ok(v) => v,
1547            Err(e) => return (self.into(), Err(e)),
1548        };
1549
1550        // if the connection was reset
1551        if self.connection.is_reset() {
1552            if header.flags.contains(TcpFlags::RST) {
1553                self.common.set_error_if_unset(TcpError::ResetReceived);
1554            }
1555
1556            let new_state = connection_was_reset(self.common, self.connection);
1557            return (new_state, Ok(pushed_len));
1558        }
1559
1560        // if received ACK, move to the "time-wait" state
1561        if self.connection.fin_was_acked() {
1562            let new_state = TimeWaitState::new(self.common, self.connection);
1563            return (new_state.into(), Ok(pushed_len));
1564        }
1565
1566        // drop all other packets
1567
1568        (self.into(), Ok(pushed_len))
1569    }
1570
1571    fn pop_packet(
1572        mut self,
1573    ) -> (
1574        TcpStateEnum<X>,
1575        Result<(TcpHeader, Payload), PopPacketError>,
1576    ) {
1577        let rv = self.connection.pop_packet(self.common.current_time());
1578        (self.into(), rv)
1579    }
1580
1581    fn clear_error(&mut self) -> Option<TcpError> {
1582        self.common.error.take()
1583    }
1584
1585    fn poll(&self) -> PollState {
1586        let mut poll_state = PollState::CONNECTED;
1587
1588        // we've received a FIN
1589        poll_state.insert(PollState::RECV_CLOSED);
1590        if self.connection.recv_buf_has_data() {
1591            poll_state.insert(PollState::READABLE);
1592        }
1593
1594        // we've sent a FIN
1595        poll_state.insert(PollState::SEND_CLOSED);
1596        assert!(!poll_state.contains(PollState::WRITABLE));
1597
1598        if self.common.error.is_some() {
1599            poll_state.insert(PollState::ERROR);
1600        }
1601
1602        poll_state
1603    }
1604
1605    fn wants_to_send(&self) -> bool {
1606        self.connection.wants_to_send()
1607    }
1608
1609    fn local_remote_addrs(&self) -> Option<(SocketAddrV4, SocketAddrV4)> {
1610        Some((self.connection.local_addr, self.connection.remote_addr))
1611    }
1612}
1613
1614impl<X: Dependencies> TimeWaitState<X> {
1615    fn new(common: Common<X>, connection: Connection<X::Instant>) -> Self {
1616        let state = TimeWaitState { common, connection };
1617
1618        // taken from /proc/sys/net/ipv4/tcp_fin_timeout
1619        let timeout = X::Duration::from_secs(60);
1620
1621        // if still in the "time-wait" state after the timeout, close it
1622        let timeout = state.common.current_time() + timeout;
1623        state.common.register_timer(timeout, |state| {
1624            if let TcpStateEnum::TimeWait(state) = state {
1625                let recv_buffer = state.connection.into_recv_buffer();
1626                let new_state =
1627                    ClosedState::new(state.common, recv_buffer, /* was_connected= */ true);
1628                new_state.into()
1629            } else {
1630                state
1631            }
1632        });
1633
1634        state
1635    }
1636}
1637
1638impl<X: Dependencies> TcpStateTrait<X> for TimeWaitState<X> {
1639    fn close(mut self) -> (TcpStateEnum<X>, Result<(), CloseError>) {
1640        // Linux does not seem to send a RST packet if a "time-wait" socket is closed while having
1641        // data in the receive buffer, probably because the peer should be in the "closed" state by
1642        // this point
1643
1644        // if the connection receives any more data, it should send an RST
1645        self.connection.send_rst_if_recv_payload();
1646
1647        // we're already in the process of closing (active close)
1648        (self.into(), Ok(()))
1649    }
1650
1651    fn shutdown(mut self, how: Shutdown) -> (TcpStateEnum<X>, Result<(), ShutdownError>) {
1652        if how == Shutdown::Read || how == Shutdown::Both {
1653            self.connection.send_rst_if_recv_payload()
1654        }
1655
1656        if how == Shutdown::Write || how == Shutdown::Both {
1657            // we're already in the process of closing (active close)
1658        }
1659
1660        (self.into(), Ok(()))
1661    }
1662
1663    fn connect<T, E>(
1664        self,
1665        _remote_addr: SocketAddrV4,
1666        _associate_fn: impl FnOnce() -> Result<(SocketAddrV4, T), E>,
1667    ) -> (TcpStateEnum<X>, Result<T, ConnectError<E>>) {
1668        (self.into(), Err(ConnectError::AlreadyConnected))
1669    }
1670
1671    fn send(self, _reader: impl Read, _len: usize) -> (TcpStateEnum<X>, Result<usize, SendError>) {
1672        (self.into(), Err(SendError::StreamClosed))
1673    }
1674
1675    fn recv(
1676        mut self,
1677        writer: impl Write,
1678        len: usize,
1679    ) -> (TcpStateEnum<X>, Result<usize, RecvError>) {
1680        let rv = self.connection.recv(writer, len);
1681
1682        // the peer won't send any more data (it sent a FIN), so if there's no more data in the
1683        // buffer, inform the socket
1684        if matches!(rv, Err(RecvError::Empty)) {
1685            return (self.into(), Err(RecvError::StreamClosed));
1686        }
1687
1688        (self.into(), rv)
1689    }
1690
1691    fn push_packet(
1692        mut self,
1693        header: &TcpHeader,
1694        payload: Payload,
1695    ) -> (TcpStateEnum<X>, Result<u32, PushPacketError>) {
1696        // make sure that the packet src/dst addresses are valid for this connection
1697        if !self.connection.packet_addrs_match(header) {
1698            // must drop the packet
1699            return (self.into(), Ok(0));
1700        }
1701
1702        // TODO: send RST for all packets?
1703        let pushed_len = match self.connection.push_packet(header, payload) {
1704            Ok(v) => v,
1705            Err(e) => return (self.into(), Err(e)),
1706        };
1707
1708        // if the connection was reset
1709        if self.connection.is_reset() {
1710            if header.flags.contains(TcpFlags::RST) {
1711                self.common.set_error_if_unset(TcpError::ResetReceived);
1712            }
1713
1714            let new_state = connection_was_reset(self.common, self.connection);
1715            return (new_state, Ok(pushed_len));
1716        }
1717
1718        (self.into(), Ok(pushed_len))
1719    }
1720
1721    fn pop_packet(
1722        mut self,
1723    ) -> (
1724        TcpStateEnum<X>,
1725        Result<(TcpHeader, Payload), PopPacketError>,
1726    ) {
1727        let rv = self.connection.pop_packet(self.common.current_time());
1728        (self.into(), rv)
1729    }
1730
1731    fn clear_error(&mut self) -> Option<TcpError> {
1732        self.common.error.take()
1733    }
1734
1735    fn poll(&self) -> PollState {
1736        let mut poll_state = PollState::CONNECTED;
1737
1738        // we've received a FIN
1739        poll_state.insert(PollState::RECV_CLOSED);
1740        if self.connection.recv_buf_has_data() {
1741            poll_state.insert(PollState::READABLE);
1742        }
1743
1744        // we've sent a FIN
1745        poll_state.insert(PollState::SEND_CLOSED);
1746        assert!(!poll_state.contains(PollState::WRITABLE));
1747
1748        if self.common.error.is_some() {
1749            poll_state.insert(PollState::ERROR);
1750        }
1751
1752        poll_state
1753    }
1754
1755    fn wants_to_send(&self) -> bool {
1756        self.connection.wants_to_send()
1757    }
1758
1759    fn local_remote_addrs(&self) -> Option<(SocketAddrV4, SocketAddrV4)> {
1760        Some((self.connection.local_addr, self.connection.remote_addr))
1761    }
1762}
1763
1764impl<X: Dependencies> CloseWaitState<X> {
1765    fn new(common: Common<X>, connection: Connection<X::Instant>) -> Self {
1766        Self { common, connection }
1767    }
1768}
1769
1770impl<X: Dependencies> TcpStateTrait<X> for CloseWaitState<X> {
1771    fn close(mut self) -> (TcpStateEnum<X>, Result<(), CloseError>) {
1772        let new_state = if self.connection.recv_buf_has_data() {
1773            // send a RST if there is still data in the receive buffer
1774            reset_connection(self.common, self.connection).into()
1775        } else {
1776            // send a FIN packet
1777            self.connection.send_fin();
1778
1779            // if the connection receives any more data, it should send an RST
1780            self.connection.send_rst_if_recv_payload();
1781
1782            LastAckState::new(self.common, self.connection).into()
1783        };
1784
1785        (new_state, Ok(()))
1786    }
1787
1788    fn rst_close(self) -> (TcpStateEnum<X>, Result<(), RstCloseError>) {
1789        let new_state = reset_connection(self.common, self.connection);
1790        (new_state.into(), Ok(()))
1791    }
1792
1793    fn shutdown(mut self, how: Shutdown) -> (TcpStateEnum<X>, Result<(), ShutdownError>) {
1794        if how == Shutdown::Read || how == Shutdown::Both {
1795            self.connection.send_rst_if_recv_payload()
1796        }
1797
1798        if how == Shutdown::Write || how == Shutdown::Both {
1799            // send a FIN packet
1800            self.connection.send_fin();
1801
1802            let new_state = LastAckState::new(self.common, self.connection);
1803            return (new_state.into(), Ok(()));
1804        }
1805
1806        (self.into(), Ok(()))
1807    }
1808
1809    fn connect<T, E>(
1810        self,
1811        _remote_addr: SocketAddrV4,
1812        _associate_fn: impl FnOnce() -> Result<(SocketAddrV4, T), E>,
1813    ) -> (TcpStateEnum<X>, Result<T, ConnectError<E>>) {
1814        (self.into(), Err(ConnectError::AlreadyConnected))
1815    }
1816
1817    fn send(
1818        mut self,
1819        reader: impl Read,
1820        len: usize,
1821    ) -> (TcpStateEnum<X>, Result<usize, SendError>) {
1822        let rv = self.connection.send(reader, len);
1823        (self.into(), rv)
1824    }
1825
1826    fn recv(
1827        mut self,
1828        writer: impl Write,
1829        len: usize,
1830    ) -> (TcpStateEnum<X>, Result<usize, RecvError>) {
1831        let rv = self.connection.recv(writer, len);
1832
1833        // the peer won't send any more data (it sent a FIN), so if there's no more data in the
1834        // buffer, inform the socket
1835        if matches!(rv, Err(RecvError::Empty)) {
1836            return (self.into(), Err(RecvError::StreamClosed));
1837        }
1838
1839        (self.into(), rv)
1840    }
1841
1842    fn push_packet(
1843        mut self,
1844        header: &TcpHeader,
1845        payload: Payload,
1846    ) -> (TcpStateEnum<X>, Result<u32, PushPacketError>) {
1847        // make sure that the packet src/dst addresses are valid for this connection
1848        if !self.connection.packet_addrs_match(header) {
1849            // must drop the packet
1850            return (self.into(), Ok(0));
1851        }
1852
1853        let pushed_len = match self.connection.push_packet(header, payload) {
1854            Ok(v) => v,
1855            Err(e) => return (self.into(), Err(e)),
1856        };
1857
1858        // if the connection was reset
1859        if self.connection.is_reset() {
1860            if header.flags.contains(TcpFlags::RST) {
1861                self.common.set_error_if_unset(TcpError::ResetReceived);
1862            }
1863
1864            let new_state = connection_was_reset(self.common, self.connection);
1865            return (new_state, Ok(pushed_len));
1866        }
1867
1868        (self.into(), Ok(pushed_len))
1869    }
1870
1871    fn pop_packet(
1872        mut self,
1873    ) -> (
1874        TcpStateEnum<X>,
1875        Result<(TcpHeader, Payload), PopPacketError>,
1876    ) {
1877        let rv = self.connection.pop_packet(self.common.current_time());
1878        (self.into(), rv)
1879    }
1880
1881    fn clear_error(&mut self) -> Option<TcpError> {
1882        self.common.error.take()
1883    }
1884
1885    fn poll(&self) -> PollState {
1886        let mut poll_state = PollState::CONNECTED;
1887
1888        if self.connection.send_buf_has_space() {
1889            poll_state.insert(PollState::WRITABLE);
1890        }
1891
1892        // we've received a FIN
1893        poll_state.insert(PollState::RECV_CLOSED);
1894        if self.connection.recv_buf_has_data() {
1895            poll_state.insert(PollState::READABLE);
1896        }
1897
1898        if self.common.error.is_some() {
1899            poll_state.insert(PollState::ERROR);
1900        }
1901
1902        poll_state
1903    }
1904
1905    fn wants_to_send(&self) -> bool {
1906        self.connection.wants_to_send()
1907    }
1908
1909    fn local_remote_addrs(&self) -> Option<(SocketAddrV4, SocketAddrV4)> {
1910        Some((self.connection.local_addr, self.connection.remote_addr))
1911    }
1912}
1913
1914impl<X: Dependencies> LastAckState<X> {
1915    fn new(common: Common<X>, connection: Connection<X::Instant>) -> Self {
1916        Self { common, connection }
1917    }
1918}
1919
1920impl<X: Dependencies> TcpStateTrait<X> for LastAckState<X> {
1921    fn close(mut self) -> (TcpStateEnum<X>, Result<(), CloseError>) {
1922        let new_state = if self.connection.recv_buf_has_data() {
1923            // send a RST if there is still data in the receive buffer
1924            reset_connection(self.common, self.connection).into()
1925        } else {
1926            // if the connection receives any more data, it should send an RST
1927            self.connection.send_rst_if_recv_payload();
1928
1929            // we're already in the process of closing (passive close)
1930            self.into()
1931        };
1932
1933        (new_state, Ok(()))
1934    }
1935
1936    fn shutdown(mut self, how: Shutdown) -> (TcpStateEnum<X>, Result<(), ShutdownError>) {
1937        if how == Shutdown::Read || how == Shutdown::Both {
1938            self.connection.send_rst_if_recv_payload()
1939        }
1940
1941        if how == Shutdown::Write || how == Shutdown::Both {
1942            // we're already in the process of closing (passive close)
1943        }
1944
1945        (self.into(), Ok(()))
1946    }
1947
1948    fn connect<T, E>(
1949        self,
1950        _remote_addr: SocketAddrV4,
1951        _associate_fn: impl FnOnce() -> Result<(SocketAddrV4, T), E>,
1952    ) -> (TcpStateEnum<X>, Result<T, ConnectError<E>>) {
1953        (self.into(), Err(ConnectError::AlreadyConnected))
1954    }
1955
1956    fn send(self, _reader: impl Read, _len: usize) -> (TcpStateEnum<X>, Result<usize, SendError>) {
1957        (self.into(), Err(SendError::StreamClosed))
1958    }
1959
1960    fn recv(
1961        mut self,
1962        writer: impl Write,
1963        len: usize,
1964    ) -> (TcpStateEnum<X>, Result<usize, RecvError>) {
1965        let rv = self.connection.recv(writer, len);
1966
1967        // the peer won't send any more data (it sent a FIN), so if there's no more data in the
1968        // buffer, inform the socket
1969        if matches!(rv, Err(RecvError::Empty)) {
1970            return (self.into(), Err(RecvError::StreamClosed));
1971        }
1972
1973        (self.into(), rv)
1974    }
1975
1976    fn push_packet(
1977        mut self,
1978        header: &TcpHeader,
1979        payload: Payload,
1980    ) -> (TcpStateEnum<X>, Result<u32, PushPacketError>) {
1981        // make sure that the packet src/dst addresses are valid for this connection
1982        if !self.connection.packet_addrs_match(header) {
1983            // must drop the packet
1984            return (self.into(), Ok(0));
1985        }
1986
1987        let pushed_len = match self.connection.push_packet(header, payload) {
1988            Ok(v) => v,
1989            Err(e) => return (self.into(), Err(e)),
1990        };
1991
1992        // if the connection was reset
1993        if self.connection.is_reset() {
1994            if header.flags.contains(TcpFlags::RST) {
1995                self.common.set_error_if_unset(TcpError::ResetReceived);
1996            }
1997
1998            let new_state = connection_was_reset(self.common, self.connection);
1999            return (new_state, Ok(pushed_len));
2000        }
2001
2002        // if received ACK, move to the "closed" state
2003        if self.connection.fin_was_acked() {
2004            let recv_buffer = self.connection.into_recv_buffer();
2005            let new_state =
2006                ClosedState::new(self.common, recv_buffer, /* was_connected= */ true);
2007            return (new_state.into(), Ok(pushed_len));
2008        }
2009
2010        (self.into(), Ok(pushed_len))
2011    }
2012
2013    fn pop_packet(
2014        mut self,
2015    ) -> (
2016        TcpStateEnum<X>,
2017        Result<(TcpHeader, Payload), PopPacketError>,
2018    ) {
2019        let rv = self.connection.pop_packet(self.common.current_time());
2020        (self.into(), rv)
2021    }
2022
2023    fn clear_error(&mut self) -> Option<TcpError> {
2024        self.common.error.take()
2025    }
2026
2027    fn poll(&self) -> PollState {
2028        let mut poll_state = PollState::CONNECTED;
2029
2030        // we've received a FIN
2031        poll_state.insert(PollState::RECV_CLOSED);
2032        if self.connection.recv_buf_has_data() {
2033            poll_state.insert(PollState::READABLE);
2034        }
2035
2036        // we've sent a FIN
2037        poll_state.insert(PollState::SEND_CLOSED);
2038        assert!(!poll_state.contains(PollState::WRITABLE));
2039
2040        if self.common.error.is_some() {
2041            poll_state.insert(PollState::ERROR);
2042        }
2043
2044        poll_state
2045    }
2046
2047    fn wants_to_send(&self) -> bool {
2048        self.connection.wants_to_send()
2049    }
2050
2051    fn local_remote_addrs(&self) -> Option<(SocketAddrV4, SocketAddrV4)> {
2052        Some((self.connection.local_addr, self.connection.remote_addr))
2053    }
2054}
2055
2056impl<X: Dependencies> RstState<X> {
2057    /// All packets must contain `TcpFlags::RST`.
2058    fn new(common: Common<X>, rst_packets: LinkedList<TcpHeader>, was_connected: bool) -> Self {
2059        debug_assert!(rst_packets.iter().all(|x| x.flags.contains(TcpFlags::RST)));
2060        assert!(!rst_packets.is_empty());
2061
2062        Self {
2063            common,
2064            send_buffer: rst_packets,
2065            was_connected,
2066        }
2067    }
2068}
2069
2070impl<X: Dependencies> TcpStateTrait<X> for RstState<X> {
2071    fn close(self) -> (TcpStateEnum<X>, Result<(), CloseError>) {
2072        // we're already in the process of closing; do nothing
2073        (self.into(), Ok(()))
2074    }
2075
2076    fn shutdown(self, how: Shutdown) -> (TcpStateEnum<X>, Result<(), ShutdownError>) {
2077        if !self.was_connected {
2078            return (self.into(), Err(ShutdownError::NotConnected));
2079        }
2080
2081        if how == Shutdown::Read || how == Shutdown::Both {
2082            // we've been reset, so nothing to do
2083        }
2084
2085        if how == Shutdown::Write || how == Shutdown::Both {
2086            // we're already in the process of closing; do nothing
2087        }
2088
2089        // TODO: should we return an error or do nothing?
2090
2091        (self.into(), Ok(()))
2092    }
2093
2094    fn connect<T, E>(
2095        self,
2096        _remote_addr: SocketAddrV4,
2097        _associate_fn: impl FnOnce() -> Result<(SocketAddrV4, T), E>,
2098    ) -> (TcpStateEnum<X>, Result<T, ConnectError<E>>) {
2099        if self.was_connected {
2100            (self.into(), Err(ConnectError::AlreadyConnected))
2101        } else {
2102            (self.into(), Err(ConnectError::InvalidState))
2103        }
2104    }
2105
2106    fn send(self, _reader: impl Read, _len: usize) -> (TcpStateEnum<X>, Result<usize, SendError>) {
2107        if self.was_connected {
2108            (self.into(), Err(SendError::StreamClosed))
2109        } else {
2110            (self.into(), Err(SendError::NotConnected))
2111        }
2112    }
2113
2114    fn recv(self, _writer: impl Write, _len: usize) -> (TcpStateEnum<X>, Result<usize, RecvError>) {
2115        if self.was_connected {
2116            (self.into(), Err(RecvError::StreamClosed))
2117        } else {
2118            (self.into(), Err(RecvError::NotConnected))
2119        }
2120    }
2121
2122    fn push_packet(
2123        self,
2124        _header: &TcpHeader,
2125        _payload: Payload,
2126    ) -> (TcpStateEnum<X>, Result<u32, PushPacketError>) {
2127        // do nothing; drop all packets received in this state
2128        (self.into(), Ok(0))
2129    }
2130
2131    fn pop_packet(
2132        mut self,
2133    ) -> (
2134        TcpStateEnum<X>,
2135        Result<(TcpHeader, Payload), PopPacketError>,
2136    ) {
2137        // if we're in this state we must have a packet queued
2138        let header = self.send_buffer.pop_front().unwrap();
2139        let packet = (header, Payload::default());
2140
2141        // we're only supposed to send RST packets in this state
2142        assert!(packet.0.flags.contains(TcpFlags::RST));
2143
2144        // if we have no more packets to send
2145        if self.send_buffer.is_empty() {
2146            let new_state = ClosedState::new(
2147                self.common,
2148                None,
2149                /* was_connected= */ self.was_connected,
2150            );
2151            return (new_state.into(), Ok(packet));
2152        }
2153
2154        (self.into(), Ok(packet))
2155    }
2156
2157    fn clear_error(&mut self) -> Option<TcpError> {
2158        self.common.error.take()
2159    }
2160
2161    fn poll(&self) -> PollState {
2162        let mut poll_state = PollState::RECV_CLOSED | PollState::SEND_CLOSED;
2163
2164        if self.common.error.is_some() {
2165            poll_state.insert(PollState::ERROR);
2166        }
2167
2168        if self.was_connected {
2169            poll_state.insert(PollState::CONNECTED);
2170        }
2171
2172        poll_state
2173    }
2174
2175    fn wants_to_send(&self) -> bool {
2176        // if we're in this state we must have a packet queued
2177        assert!(!self.send_buffer.is_empty());
2178        true
2179    }
2180
2181    fn local_remote_addrs(&self) -> Option<(SocketAddrV4, SocketAddrV4)> {
2182        None
2183    }
2184}
2185
2186impl<X: Dependencies> ClosedState<X> {
2187    fn new(common: Common<X>, recv_buffer: Option<RecvQueue>, was_connected: bool) -> Self {
2188        let recv_buffer = recv_buffer.unwrap_or_else(|| RecvQueue::new(Seq::new(0)));
2189
2190        if !was_connected {
2191            assert!(recv_buffer.is_empty());
2192        }
2193
2194        Self {
2195            common,
2196            recv_buffer,
2197            was_connected,
2198        }
2199    }
2200}
2201
2202impl<X: Dependencies> TcpStateTrait<X> for ClosedState<X> {
2203    fn close(self) -> (TcpStateEnum<X>, Result<(), CloseError>) {
2204        // already closed; do nothing
2205        (self.into(), Ok(()))
2206    }
2207
2208    fn shutdown(self, how: Shutdown) -> (TcpStateEnum<X>, Result<(), ShutdownError>) {
2209        if !self.was_connected {
2210            return (self.into(), Err(ShutdownError::NotConnected));
2211        }
2212
2213        if how == Shutdown::Read || how == Shutdown::Both {
2214            // we've been reset, so nothing to do
2215        }
2216
2217        if how == Shutdown::Write || how == Shutdown::Both {
2218            // we're already in the process of closing; do nothing
2219        }
2220
2221        // TODO: should we return an error or do nothing?
2222
2223        (self.into(), Ok(()))
2224    }
2225
2226    fn connect<T, E>(
2227        self,
2228        _remote_addr: SocketAddrV4,
2229        _associate_fn: impl FnOnce() -> Result<(SocketAddrV4, T), E>,
2230    ) -> (TcpStateEnum<X>, Result<T, ConnectError<E>>) {
2231        if self.was_connected {
2232            (self.into(), Err(ConnectError::AlreadyConnected))
2233        } else {
2234            (self.into(), Err(ConnectError::InvalidState))
2235        }
2236    }
2237
2238    fn send(self, _reader: impl Read, _len: usize) -> (TcpStateEnum<X>, Result<usize, SendError>) {
2239        if !self.was_connected {
2240            return (self.into(), Err(SendError::NotConnected));
2241        }
2242
2243        (self.into(), Err(SendError::StreamClosed))
2244    }
2245
2246    fn recv(
2247        mut self,
2248        writer: impl Write,
2249        len: usize,
2250    ) -> (TcpStateEnum<X>, Result<usize, RecvError>) {
2251        if !self.was_connected {
2252            return (self.into(), Err(RecvError::NotConnected));
2253        }
2254
2255        if self.recv_buffer.is_empty() {
2256            return (self.into(), Err(RecvError::StreamClosed));
2257        }
2258
2259        let rv = self.recv_buffer.read(writer, len).map_err(RecvError::Io);
2260
2261        (self.into(), rv)
2262    }
2263
2264    fn clear_error(&mut self) -> Option<TcpError> {
2265        self.common.error.take()
2266    }
2267
2268    fn poll(&self) -> PollState {
2269        let mut poll_state = PollState::CLOSED;
2270
2271        poll_state.insert(PollState::RECV_CLOSED);
2272        if !self.recv_buffer.is_empty() {
2273            poll_state.insert(PollState::READABLE);
2274        }
2275
2276        poll_state.insert(PollState::SEND_CLOSED);
2277        assert!(!poll_state.contains(PollState::WRITABLE));
2278
2279        if self.was_connected {
2280            poll_state.insert(PollState::CONNECTED);
2281        }
2282
2283        if self.common.error.is_some() {
2284            poll_state.insert(PollState::ERROR);
2285        }
2286
2287        poll_state
2288    }
2289
2290    fn wants_to_send(&self) -> bool {
2291        false
2292    }
2293
2294    fn local_remote_addrs(&self) -> Option<(SocketAddrV4, SocketAddrV4)> {
2295        None
2296    }
2297}
2298
2299/// Reset the connection, get the resulting RST packet, and return a new `RstState` that will send
2300/// this RST packet.
2301fn reset_connection<X: Dependencies>(
2302    common: Common<X>,
2303    mut connection: Connection<X::Instant>,
2304) -> RstState<X> {
2305    connection.send_rst();
2306
2307    let new_state = connection_was_reset(common, connection);
2308
2309    let TcpStateEnum::Rst(new_state) = new_state else {
2310        panic!("We called `send_rst()` above but aren't now in the \"rst\" state: {new_state:?}");
2311    };
2312
2313    new_state
2314}
2315
2316/// For a connection that was reset (either by us or by the peer), check if it has a remaining RST
2317/// packet to send, and return a new `RstState` that will send this RST packet or a new
2318/// `ClosedState` if not.
2319fn connection_was_reset<X: Dependencies>(
2320    mut common: Common<X>,
2321    mut connection: Connection<X::Instant>,
2322) -> TcpStateEnum<X> {
2323    assert!(connection.is_reset());
2324
2325    let now = common.current_time();
2326
2327    // check if there's an RST packet to send
2328    if let Ok((header, payload)) = connection.pop_packet(now) {
2329        assert!(payload.is_empty());
2330        debug_assert!(connection.pop_packet(now).is_err());
2331
2332        common.set_error_if_unset(TcpError::ResetSent);
2333
2334        let rst_packets = [header].into_iter().collect();
2335        RstState::new(common, rst_packets, /* was_connected= */ true).into()
2336    } else {
2337        // the receive buffer is cleared when a connection is reset, which is why we pass `None`
2338        ClosedState::new(common, None, /* was_connected= */ true).into()
2339    }
2340}