// tcp/connection.rs

1use bytes::{Buf, Bytes};
2use std::io::{Read, Write};
3use std::net::SocketAddrV4;
4
5use crate::buffer::{RecvQueue, Segment};
6use crate::seq::{Seq, SeqRange};
7use crate::util::time::Instant;
8use crate::window_scaling::WindowScaling;
9use crate::{
10    Ipv4Header, Payload, PopPacketError, PushPacketError, RecvError, SendError, TcpConfig,
11    TcpFlags, TcpHeader,
12};
13
/// Information for a TCP connection. Equivalent to the Transmission Control Block (TCB).
///
/// `I` is the clock type; `pop_packet()` passes an `I` timestamp to the send buffer when a
/// segment is transmitted.
#[derive(Debug)]
pub(crate) struct Connection<I: Instant> {
    /// Configuration given at construction (`new()` reads `window_scaling_enabled` from it).
    pub(crate) config: TcpConfig,
    /// Our address; incoming packets must be addressed to it and outgoing packets use it as the
    /// source.
    pub(crate) local_addr: SocketAddrV4,
    /// The peer's address; incoming packets must come from it and outgoing packets are sent to it.
    pub(crate) remote_addr: SocketAddrV4,
    /// Sending half of the connection (send buffer, peer's window, FIN/SYN-ack state).
    pub(crate) send: ConnectionSend<I>,
    /// Receiving half of the connection. `None` until the peer's first SYN is processed, since
    /// the receive buffer needs the peer's initial sequence number.
    pub(crate) recv: Option<ConnectionRecv>,
    /// Set when an acknowledgement should be sent; cleared whenever `pop_packet()` emits a packet.
    pub(crate) need_to_ack: bool,
    /// The receive window most recently advertised, as the peer will interpret it (after the
    /// window scale shift). `None` until the first packet is popped. Used to detect when a window
    /// update needs to be sent.
    pub(crate) last_advertised_window: Option<u32>,
    /// Window scale negotiation state (RFC 7323).
    pub(crate) window_scaling: WindowScaling,
    /// If set, receiving any new (in-window) payload bytes resets the connection.
    pub(crate) send_rst_if_recv_payload: bool,
    /// Set once a RST has been received or we have decided to send one.
    pub(crate) is_reset: bool,
    /// Set while a RST packet is queued; cleared by `pop_packet()` after the RST is emitted.
    pub(crate) need_to_send_rst: bool,
}
29
30impl<I: Instant> Connection<I> {
31    /// The max number of bytes allowed in the send and receive buffers. These should be made
32    /// dynamic in the future.
33    const SEND_BUF_MAX: usize = 100_000;
34    const RECV_BUF_MAX: u32 = 100_000;
35
36    pub fn new(
37        local_addr: SocketAddrV4,
38        remote_addr: SocketAddrV4,
39        send_initial_seq: Seq,
40        config: TcpConfig,
41    ) -> Self {
42        let mut rv = Self {
43            config,
44            local_addr,
45            remote_addr,
46            send: ConnectionSend::new(send_initial_seq),
47            recv: None,
48            need_to_ack: true,
49            last_advertised_window: None,
50            window_scaling: WindowScaling::new(),
51            send_rst_if_recv_payload: false,
52            is_reset: false,
53            need_to_send_rst: false,
54        };
55
56        // disable window scaling if it's disabled in the config
57        if !rv.config.window_scaling_enabled {
58            rv.window_scaling.disable();
59        }
60
61        rv
62    }
63
64    pub fn into_recv_buffer(self) -> Option<RecvQueue> {
65        if let Some(recv) = self.recv {
66            return Some(recv.buffer);
67        }
68
69        None
70    }
71
72    /// Returns `true` if the packet header src/dst addresses match this connection.
73    pub fn packet_addrs_match(&self, header: &TcpHeader) -> bool {
74        header.src() == self.remote_addr && header.dst() == self.local_addr
75    }
76
77    pub fn send_fin(&mut self) {
78        self.send.buffer.add_fin();
79        self.send.is_closed = true;
80    }
81
82    pub fn send_rst(&mut self) {
83        self.need_to_send_rst = true;
84        self.is_reset = true;
85    }
86
87    /// If any new payload bytes are received, the connection will be reset.
88    pub fn send_rst_if_recv_payload(&mut self) {
89        self.send_rst_if_recv_payload = true;
90    }
91
92    pub fn send(&mut self, reader: impl Read, len: usize) -> Result<usize, SendError> {
93        // if the buffer is full
94        if !self.send_buf_has_space() {
95            return Err(SendError::Full);
96        }
97
98        let send_buffer_len = self.send.buffer.len() as usize;
99        let send_buffer_space = Self::SEND_BUF_MAX.saturating_sub(send_buffer_len);
100
101        let len = std::cmp::min(len, send_buffer_space);
102        if let Err(e) = self.send.buffer.add_data(reader, len) {
103            return Err(SendError::Io(e));
104        }
105
106        Ok(len)
107    }
108
109    pub fn recv(&mut self, writer: impl Write, len: usize) -> Result<usize, RecvError> {
110        let recv = self.recv.as_mut().unwrap();
111
112        if recv.buffer.is_empty() {
113            return Err(RecvError::Empty);
114        }
115
116        recv.buffer.read(writer, len).map_err(RecvError::Io)
117    }
118
119    pub fn push_packet(
120        &mut self,
121        header: &TcpHeader,
122        payload: Payload,
123    ) -> Result<u32, PushPacketError> {
124        if self.is_reset {
125            panic!(
126                "The connection has already been reset, so why are we being given more packets?"
127            );
128        }
129
130        // process RST packets
131        if header.flags.contains(TcpFlags::RST) {
132            let seq = Seq::new(header.seq);
133            let recv_window = self.recv_window();
134
135            // TODO: figure out how to properly handle weird RST packets (for example RST packets
136            // with payload data)
137
138            let Some(recv_window) = recv_window else {
139                // we haven't received a SYN yet, so we'll trust the RST
140                self.is_reset = true;
141                return Ok(0);
142            };
143
144            // RFC 9293 3.10.7.4.:
145            // > If the RCV.WND is zero, no segments will be acceptable, but special allowance
146            // > should be made to accept valid ACKs, URGs, and RSTs.
147            if seq == recv_window.start {
148                // RFC 9293 3.10.7.4.:
149                // > If the RST bit is set and the sequence number exactly matches the next expected
150                // > sequence number (RCV.NXT), then TCP endpoints MUST reset the connection in the
151                // > manner prescribed below according to the connection state.
152
153                self.is_reset = true;
154                return Ok(0);
155            }
156
157            if recv_window.contains(seq) {
158                // RFC 9293 3.10.7.4.:
159                // > If the RST bit is set and the sequence number does not exactly match the next
160                // > expected sequence value, yet is within the current receive window, TCP
161                // > endpoints MUST send an acknowledgment (challenge ACK):
162                // >
163                // > <SEQ=SND.NXT><ACK=RCV.NXT><CTL=ACK>
164                // >
165                // > After sending the challenge ACK, TCP endpoints MUST drop the unacceptable
166                // > segment and stop processing the incoming packet further.
167
168                // TODO: Setting `need_to_ack` to true isn't enough to send an acknowledgement
169                // exactly as described above, since the next `pop_packet` may next try to
170                // retransmit something which would have a different sequence number. But not sure
171                // if this really matters in practice since the peer would receive the packet and
172                // send another RST packet based on the ACK value we send.
173
174                self.need_to_ack = true;
175                return Ok(0);
176            }
177
178            // RFC 9293 3.10.7.4.:
179            // > If the RST bit is set and the sequence number is outside the current receive
180            // > window, silently drop the segment.
181
182            return Ok(0);
183        }
184
185        // process the first SYN packet
186        if self.recv.is_none() && header.flags.contains(TcpFlags::SYN) {
187            // we needed to know the sender's initial sequence number before we could initialize the
188            // receiving part of the connection
189            let seq = Seq::new(header.seq);
190            self.recv = Some(ConnectionRecv::new(seq));
191
192            self.window_scaling.received_syn(header.window_scale);
193        }
194
195        // We need to keep track of if the original packet had the SYN flag set, even if we trim a
196        // old retransmitted SYN flag from the packet below. If it was sent with a SYN flag, then we
197        // must not apply the window scale to the window size in the packet, even if the SYN's
198        // sequence number isn't within the receive window.
199        //
200        // RFC 7323 2.2.:
201        // > The window field in a segment where the SYN bit is set (i.e., a <SYN> or <SYN,ACK>)
202        // > MUST NOT be scaled.
203        //
204        // TODO: be careful about this if we support a reassembly queue in the future
205        let original_packet_had_syn = header.flags.contains(TcpFlags::SYN);
206
207        // trim the segment so that it only contains data/flags that fit within the receive window
208        let recv_window = self.recv_window().unwrap();
209        let Some((header, payload)) = trim_segment(header, payload, &recv_window) else {
210            // the sequence range of the segment does not overlap with the receive window, so we
211            // must drop the packet and send an ACK
212
213            self.need_to_ack = true;
214            return Ok(0);
215        };
216
217        let Some(recv) = self.recv.as_mut() else {
218            // we received a non-SYN packet before the first SYN packet
219            self.send_rst();
220            return Ok(0);
221        };
222
223        // if we've been told to send a RST when we receive new payload data, and we did receive new
224        // payload data
225        if self.send_rst_if_recv_payload && !payload.is_empty() {
226            self.send_rst();
227            return Ok(0);
228        }
229
230        // if we've previously received a FIN packet, and now we've received a payload/SYN/FIN
231        // packet that is within the receive window
232        if recv.is_closed
233            && (!payload.is_empty() || header.flags.intersects(TcpFlags::SYN | TcpFlags::FIN))
234        {
235            self.send_rst();
236            return Ok(0);
237        }
238
239        let mut pushed_len = 0;
240        // the receive buffer's initial next sequence number; useful so we can check if we need to
241        // acknowledge or not
242        let initial_seq = recv.buffer.next_seq();
243
244        if !recv.is_closed {
245            if header.flags.contains(TcpFlags::SYN) {
246                if recv.buffer.syn_added() {
247                    // this is the second SYN we've received
248
249                    // TODO: We can follow RFC 793 or RFC 5961 here. 793 is probably easiest, and we
250                    // should send an RST and move to the "closed" state.
251
252                    self.send_rst();
253                    return Ok(0);
254                }
255
256                recv.buffer.add_syn();
257            }
258
259            let syn_len = if header.flags.contains(TcpFlags::SYN) {
260                1
261            } else {
262                0
263            };
264
265            let payload_len = payload.len();
266            let payload_seq = (payload_len != 0).then_some(Seq::new(header.seq) + syn_len);
267            let fin_seq = header
268                .flags
269                .contains(TcpFlags::FIN)
270                .then_some(Seq::new(header.seq) + syn_len + payload_len);
271
272            if let Some(payload_seq) = payload_seq {
273                if payload_seq == recv.buffer.next_seq() {
274                    pushed_len += payload.len();
275                    for chunk in payload.0 {
276                        recv.buffer.add(chunk);
277                    }
278                } else {
279                    // TODO: store (truncated?) out-of-order packet
280                }
281            }
282
283            if let Some(fin_seq) = fin_seq {
284                if fin_seq == recv.buffer.next_seq() {
285                    recv.buffer.add_fin();
286                    recv.is_closed = true;
287                } else {
288                    // TODO: store (truncated?) out of order packet
289                }
290            }
291        }
292
293        // we've added to the receive buffer (payload, syn, or fin), so we need to send an
294        // acknowledgement
295        if recv.buffer.next_seq() != initial_seq {
296            self.need_to_ack = true;
297        }
298
299        // update the send window, applying the window scale shift only if it wasn't a SYN packet
300        // TODO: should we still update the window if the ACK was not in the valid range?
301        if original_packet_had_syn {
302            self.send.window = u32::from(header.window_size);
303        } else {
304            self.send.window =
305                u32::from(header.window_size) << self.window_scaling.send_window_scale_shift();
306        }
307
308        if header.flags.contains(TcpFlags::ACK) {
309            let valid_ack_range = SeqRange::new(
310                self.send.buffer.start_seq() + 1,
311                self.send.buffer.next_seq() + 1,
312            );
313
314            if valid_ack_range.contains(Seq::new(header.ack)) {
315                // the SYN is always first, so if a new sequence number has been acknowledged, then
316                // either it's acknowledging the SYN, or the SYN has been acknowledged in the past
317                if Seq::new(header.ack) != self.send.buffer.start_seq() {
318                    self.send.syn_acked = true;
319                }
320
321                self.send.buffer.advance_start(Seq::new(header.ack));
322            }
323        }
324
325        Ok(pushed_len)
326    }
327
    /// Builds the next outgoing packet. Returns `PopPacketError::NoPacket` when `next_segment()`
    /// has nothing to send.
    ///
    /// Header contents:
    /// - Once the peer's SYN has been received, every packet carries the ACK flag. Its
    ///   acknowledgement number is the receive buffer's next sequence number.
    /// - SYN packets may carry the window scale option, when `window_scaling.can_send_window_scale()`
    ///   allows it, and advertise an unscaled window.
    /// - All other packets advertise the receive window shifted right by the receive window scale.
    ///
    /// Side effects: updates `last_advertised_window` and clears `need_to_ack`. It also tells the
    /// send buffer that everything up to the segment's end sequence number was transmitted at
    /// `now`, and clears `need_to_send_rst` after emitting a RST.
    ///
    /// # Panics
    ///
    /// - If the advertised window does not fit in the header's window field.
    /// - If a RST is emitted while `need_to_send_rst` is not set.
    pub fn pop_packet(&mut self, now: I) -> Result<(TcpHeader, Payload), PopPacketError> {
        let (seq_range, mut flags, payload) =
            self.next_segment().ok_or(PopPacketError::NoPacket)?;

        // After this point we must always send a packet. If we don't, then we're not being
        // consistent with `self.wants_to_send()`.
        debug_assert!(self.wants_to_send());

        let header_ack = if let Some(recv) = self.recv.as_ref() {
            // we've received a SYN packet (either now or in the past), so should always acknowledge
            flags.insert(TcpFlags::ACK);
            recv.buffer.next_seq()
        } else {
            // not setting the ACK flag, so this can probably be anything
            Seq::new(0)
        };

        let header_window_size;
        let header_window_scale;

        if flags.contains(TcpFlags::SYN) {
            if self.window_scaling.can_send_window_scale() {
                // The receive buffer capacity at the time the SYN is sent decides the window
                // scaling to use. This effectively limits future receive buffer capacity increases
                // since the receive window will forever have a ceiling set here by the window
                // scale.
                let shift = WindowScaling::scale_shift_for_max_window(self.recv_buffer_capacity());
                header_window_scale = Some(shift);
            } else {
                header_window_scale = None;
            }

            // don't actually apply this window scale in the SYN packet
            //
            // RFC 7323 2.2.:
            // > The window field in a segment where the SYN bit is set (i.e., a <SYN> or <SYN,ACK>)
            // > MUST NOT be scaled.
            header_window_size = self.recv_window_len();
            self.last_advertised_window = Some(header_window_size);

            // Make sure we're sending a valid 2-byte window size. We haven't called
            // `WindowScaling::sent_syn()` yet, so `Self::recv_window_len()` should not have
            // returned a window size larger than `u16::MAX`.
            debug_assert!(header_window_size <= u16::MAX as u32);

            // record the (possibly absent) scale we offered; this must happen after the window
            // size above is computed, per the comment above
            self.window_scaling.sent_syn(header_window_scale);
        } else {
            // don't send a window scale
            //
            // RFC 7323 2.1.:
            // > The exponent of the scale factor is carried in a TCP option, Window Scale. This
            // > option is sent only in a <SYN> segment (a segment with the SYN bit on), [...]
            header_window_scale = None;

            // NOTE(review): RST segments also take this branch, and `_next_segment()` returns them
            // before its "window scaling is configured" check; confirm that
            // `recv_window_scale_shift()` is safe to call in that state.
            let shift = self.window_scaling.recv_window_scale_shift();
            header_window_size = self.recv_window_len() >> shift;

            // this is the value the peer will see (precision is intentionally lost due to bit-shift)
            self.last_advertised_window = Some(header_window_size << shift);
        }

        let header = TcpHeader {
            ip: Ipv4Header {
                src: *self.local_addr.ip(),
                dst: *self.remote_addr.ip(),
            },
            flags,
            src_port: self.local_addr.port(),
            dst_port: self.remote_addr.port(),
            seq: seq_range.start.into(),
            ack: header_ack.into(),
            // panics if the (shifted) window doesn't fit in the header's window field
            window_size: header_window_size.try_into().unwrap(),
            selective_acks: None,
            window_scale: header_window_scale,
            timestamp: None,
            timestamp_echo: None,
        };

        // we're sending the most up-to-date acknowledgement
        self.need_to_ack = false;

        // inform the buffer that we transmitted this segment
        self.send.buffer.mark_as_transmitted(seq_range.end, now);

        if header.flags.contains(TcpFlags::RST) {
            // `_next_segment()` only produces RST segments while `need_to_send_rst` is set
            assert!(self.need_to_send_rst);
            self.need_to_send_rst = false;
        }

        Ok((header, payload))
    }
419
420    /// Returns a segment that is ready to send. This may be a data segment (a segment containing a
421    /// SYN/FIN flag and/or payload data), a RST segment, or an empty segment. Even if this returns
422    /// an empty segment, it must be sent with the correct acknowledgement number, window size, etc
423    /// as it may represent an acknowledgement or window update.
424    fn next_segment(&self) -> Option<(SeqRange, TcpFlags, Payload)> {
425        // should be inlined
426        self._next_segment()
427    }
428
429    /// Returns true if ready to send a packet.
430    pub fn wants_to_send(&self) -> bool {
431        // should be inlined
432        self._next_segment().is_some()
433    }
434
    /// Do not call directly. Use either `next_segment()` or `wants_to_send()`.
    ///
    /// Decides what, if anything, to send next, in this priority order:
    /// 1. a pending RST;
    /// 2. nothing at all if the connection is reset;
    /// 3. a data segment from `next_data_segment()`;
    /// 4. an empty segment, if an ACK is pending or the advertised window would change.
    ///
    /// Non-SYN segments other than RSTs are withheld until window scaling is configured.
    ///
    /// Since `wants_to_send()` is only interested in whether the result is `Some`, by inlining this
    /// function the compiler should hopefully optimize it to remove unnecessary values that will be
    /// immediately discarded. I'm uncertain whether there's really much that can be optimized here
    /// though, but splitting it into functions will at least help us notice if either function is
    /// showing up in a profile/heatmap. Since the function is large and is `inline(always)` we only
    /// call it from two functions, `next_segment()` and `wants_to_send()`.
    #[inline(always)]
    fn _next_segment(&self) -> Option<(SeqRange, TcpFlags, Payload)> {
        // A queued RST takes priority over everything else. It returns before the window scaling
        // check at the end of this function.
        if self.need_to_send_rst {
            // use the sequence number of the first untransmitted segment if there is one,
            // otherwise the send buffer's next sequence number
            let seq = self
                .send
                .buffer
                .next_not_transmitted(0)
                .map(|x| x.0)
                .unwrap_or(self.send.buffer.next_seq());

            let seq_range = SeqRange::new(seq, seq);
            return Some((seq_range, TcpFlags::RST, Payload::default()));
        }

        // if the connection has been reset and we don't need to send a RST packet, never send any
        // future packets
        if self.is_reset {
            return None;
        }

        let (seq_range, syn_fin_flags, payload) = 'packet: {
            // if we have syn/fin/payload data to send
            if let Some((seq_range, syn_fin_flags, payload)) = self.next_data_segment() {
                break 'packet (seq_range, syn_fin_flags, payload);
            }

            let mut send_empty_packet = false;

            // do we need to send an acknowledgement?
            if self.need_to_ack {
                send_empty_packet = true;
            }

            // do we need to send a window update?
            if let Some(window) = self.recv_window().map(|x| x.len()) {
                let window_scale = self.window_scaling.recv_window_scale_shift();

                // the window as the peer would see it after scaling (low bits are lost)
                let apparent_window = window >> window_scale << window_scale;

                // `last_advertised_window` is `None` before the first packet, so this also fires
                // until something has been sent
                if self.last_advertised_window != Some(apparent_window) {
                    send_empty_packet = true;
                }
            }

            if send_empty_packet {
                // use the sequence number of the next unsent message if we have one buffered,
                // otherwise get the next sequence number from the buffer
                let seq = self
                    .send
                    .buffer
                    .next_not_transmitted(0)
                    .map(|x| x.0)
                    .unwrap_or(self.send.buffer.next_seq());

                let seq_range = SeqRange::new(seq, seq);
                break 'packet (seq_range, TcpFlags::empty(), Payload::default());
            }

            return None;
        };

        // if not sending a SYN packet and window scaling isn't yet confirmed
        if !syn_fin_flags.contains(TcpFlags::SYN) && !self.window_scaling.is_configured() {
            // we cannot send a non-SYN packet since non-SYN packets must apply window scaling, but
            // we haven't yet confirmed if we're using window scaling or not
            return None;
        }

        Some((seq_range, syn_fin_flags, payload))
    }
513
    /// Returns a data segment that is ready to send. This is a segment containing a SYN/FIN flag
    /// and/or payload data. Even if this returns `None`, we may still want to send some other
    /// segment such as an acknowledgement or window update (see `Self::next_segment`).
    ///
    /// Consecutive untransmitted send-buffer entries are packed into one segment. Packing stops at
    /// the first entry whose starting sequence number is outside the send window, or once
    /// `MAX_BYTES_PER_PACKET` payload bytes have been collected. Data chunks are truncated to fit
    /// both limits.
    fn next_data_segment(&self) -> Option<(SeqRange, TcpFlags, Payload)> {
        let send_window = self.send_window();

        let mut chunks = Vec::new();
        let mut syn_fin_flags = TcpFlags::empty();
        // sequence number of the first entry included in this segment
        let mut seq_start = None;
        // sequence space consumed so far (SYN/FIN count as one each); also used as the offset
        // passed to `next_not_transmitted` to fetch the following entry
        let mut seq_len = 0;
        let mut payload_bytes_len = 0;

        // roughly represents the MSS
        // TODO: handle the MSS properly
        const MAX_BYTES_PER_PACKET: u32 = 1500;

        // do we have syn/fin/payload data to send?
        while let Some((seq, segment)) = self.send.buffer.next_not_transmitted(seq_len) {
            // if no bytes of this segment fit within the send window
            if !send_window.contains(seq) {
                break;
            }

            // if we can't send any more payload bytes
            if payload_bytes_len == MAX_BYTES_PER_PACKET {
                break;
            }

            // if this is the first returned segment, keep track of the start
            if seq_start.is_none() {
                seq_start = Some(seq);
            }

            match segment {
                Segment::Syn => {
                    syn_fin_flags.insert(TcpFlags::SYN);
                    seq_len += segment.len();
                }
                Segment::Fin => {
                    syn_fin_flags.insert(TcpFlags::FIN);
                    seq_len += segment.len();
                }
                Segment::Data(mut chunk) => {
                    // limit the chunk by both the remaining per-packet budget and the remaining
                    // send window
                    let allowed_payload_len =
                        MAX_BYTES_PER_PACKET.saturating_sub(payload_bytes_len);
                    let allowed_seq_len = send_window.end - seq;
                    let allowed_len = std::cmp::min(allowed_payload_len, allowed_seq_len);

                    chunk.truncate(std::cmp::min(chunk.len(), allowed_len.try_into().unwrap()));

                    let chunk_len: u32 = chunk.len().try_into().unwrap();
                    payload_bytes_len += chunk_len;
                    seq_len += chunk_len;

                    chunks.push(chunk);
                }
            };

            // we shouldn't be sending more than allowed
            debug_assert!(payload_bytes_len <= MAX_BYTES_PER_PACKET);
        }

        if !chunks.is_empty() || !syn_fin_flags.is_empty() {
            // at least one entry was included, so `seq_start` was set in the loop
            let seq_start = seq_start.unwrap();
            let seq_range = SeqRange::new(seq_start, seq_start + seq_len);
            return Some((seq_range, syn_fin_flags, Payload(chunks)));
        }

        None
    }
584
585    /// Returns true if we received a RST packet, or if we want to send a RST packet.
586    pub fn is_reset(&self) -> bool {
587        self.is_reset
588    }
589
590    /// Returns true if we received a SYN packet from the peer.
591    pub fn received_syn(&self) -> bool {
592        // we don't construct the receive part of the connection until we've received the SYN
593        self.recv.is_some()
594    }
595
596    /// Returns true if we received a FIN packet from the peer.
597    pub fn received_fin(&self) -> bool {
598        self.recv.as_ref().map(|x| x.is_closed).unwrap_or(false)
599    }
600
601    /// Returns true if the peer acknowledged the SYN packet we sent.
602    pub fn syn_was_acked(&self) -> bool {
603        self.send.syn_acked
604    }
605
606    /// Returns true if the peer acknowledged the FIN packet we sent.
607    pub fn fin_was_acked(&self) -> bool {
608        self.send.is_closed && self.send.buffer.start_seq() == self.send.buffer.next_seq()
609    }
610
611    /// Returns true if the send buffer has space available. Does not consider whether the
612    /// connection is open/closed, either due to FIN packets or `shutdown()`.
613    pub fn send_buf_has_space(&self) -> bool {
614        let send_buffer_len = self.send.buffer.len() as usize;
615
616        send_buffer_len < Self::SEND_BUF_MAX
617    }
618
619    /// Returns true if the recv buffer has data to read. Does not consider whether the connection
620    /// is open/closed, either due to FIN packets or `shutdown()`.
621    pub fn recv_buf_has_data(&self) -> bool {
622        let is_empty = self
623            .recv
624            .as_ref()
625            .map(|x| x.buffer.is_empty())
626            .unwrap_or(true);
627        !is_empty
628    }
629
630    pub(crate) fn send_window(&self) -> SeqRange {
631        // the buffer stores unsent/unacked data, so the buffer starts at the lowest unacked
632        // sequence number
633        let window_left = self.send.buffer.start_seq();
634        SeqRange::new(window_left, window_left + self.send.window)
635    }
636
637    /// Returns the size of the receive window. This is useful when we only need the size of the
638    /// window and we may not have received the SYN packet yet, so cannot construct the range.
639    pub(crate) fn recv_window_len(&self) -> u32 {
640        if let Some(recv_window) = self.recv_window() {
641            return recv_window.len();
642        }
643
644        let window_max = self.window_scaling.recv_window_max();
645        std::cmp::min(self.recv_buffer_capacity(), window_max)
646    }
647
648    /// Returns the receive window if we've received a SYN packet.
649    pub(crate) fn recv_window(&self) -> Option<SeqRange> {
650        let recv = self.recv.as_ref()?;
651        let window_left = recv.buffer.next_seq();
652        let window_max = self.window_scaling.recv_window_max();
653        let window_len = self
654            .recv_buffer_capacity()
655            .saturating_sub(recv.buffer.len());
656        let window_len = std::cmp::min(window_len, window_max);
657        Some(SeqRange::new(window_left, window_left + window_len))
658    }
659
660    /// The total capacity of the receive buffer.
661    fn recv_buffer_capacity(&self) -> u32 {
662        Self::RECV_BUF_MAX
663    }
664}
665
/// The sending half of a [`Connection`].
#[derive(Debug)]
pub(crate) struct ConnectionSend<I: Instant> {
    /// Unacknowledged and not-yet-transmitted outgoing data, including any SYN/FIN.
    pub(crate) buffer: super::buffer::SendQueue<I>,
    /// The peer's advertised receive window in bytes, after window scaling is applied.
    pub(crate) window: u32,
    /// Set once a FIN has been queued (see `Connection::send_fin`).
    pub(crate) is_closed: bool,
    /// Set once the peer has acknowledged our SYN.
    pub(crate) syn_acked: bool,
}
673
674impl<I: Instant> ConnectionSend<I> {
675    pub fn new(initial_seq: Seq) -> Self {
676        Self {
677            buffer: super::buffer::SendQueue::new(initial_seq),
678            // we don't know the peer's receive window, so choose something conservative
679            window: 2048,
680            is_closed: false,
681            syn_acked: false,
682        }
683    }
684}
685
/// The receiving half of a [`Connection`]. It only exists once the peer's SYN has been received.
#[derive(Debug)]
pub(crate) struct ConnectionRecv {
    /// In-order data (and SYN/FIN) received from the peer.
    pub(crate) buffer: super::buffer::RecvQueue,
    /// Set once the peer's FIN has been accepted into the buffer.
    pub(crate) is_closed: bool,
}
691
692impl ConnectionRecv {
693    pub fn new(initial_seq: Seq) -> Self {
694        Self {
695            buffer: super::buffer::RecvQueue::new(initial_seq),
696            is_closed: false,
697        }
698    }
699}
700
701/// Trims the segment `header` and `payload` such that only bytes in the sequence `range` remain.
702/// This may modify the segment sequence number, SYN/FIN flags, or payload.
703fn trim_segment(
704    header: &TcpHeader,
705    payload: Payload,
706    range: &SeqRange,
707) -> Option<(TcpHeader, Payload)> {
708    let seq = Seq::new(header.seq);
709    let syn_len = if header.flags.contains(TcpFlags::SYN) {
710        1
711    } else {
712        0
713    };
714    let fin_len = if header.flags.contains(TcpFlags::FIN) {
715        1
716    } else {
717        0
718    };
719    let payload_len = payload.len();
720
721    let header_range = SeqRange::new(seq, seq + syn_len + payload_len + fin_len);
722    let intersection = header_range.intersection(range)?;
723
724    if intersection == header_range {
725        // in the common case where the segment is completely contained within the range, return
726        // early without any modifications
727        return Some((*header, payload));
728    }
729
730    let include_syn = syn_len == 1 && range.contains(header_range.start);
731    let include_fin = fin_len == 1 && range.contains(header_range.end - 1);
732
733    let payload_seq = seq + syn_len;
734    let new_payload = match trim_payload(payload_seq, payload, range) {
735        Some((new_seq, new_payload)) => {
736            assert_eq!(
737                new_seq,
738                intersection.start + if include_syn { 1 } else { 0 }
739            );
740            new_payload
741        }
742        None => Payload::default(),
743    };
744
745    let mut new_flags = header.flags;
746    new_flags.set(TcpFlags::SYN, include_syn);
747    new_flags.set(TcpFlags::FIN, include_fin);
748
749    let new_header = TcpHeader {
750        seq: intersection.start.into(),
751        flags: new_flags,
752        ..*header
753    };
754
755    Some((new_header, new_payload))
756}
757
758/// Trims `payload`, which starts at a given `seq` number, such that only bytes in the sequence
759/// `range` remain.
760///
761/// If the two ranges do not intersect `None` will be returned. A `None` is also returned if the
762/// range intersects the payload twice, for example if the payload covers the range 100..200 and the
763/// given range covers 180..120, but this shouldn't occur for reasonable TCP sequence number ranges.
764/// The returned payload may be empty if the original `payload` was empty or the `range` was empty,
765/// but they still intersect according to [`SeqRange::intersection`].
766fn trim_payload(seq: Seq, mut payload: Payload, range: &SeqRange) -> Option<(Seq, Payload)> {
767    let payload_range = SeqRange::new(seq, seq + payload.len());
768    let intersection = payload_range.intersection(range)?;
769
770    if payload_range == intersection {
771        // in the common case where the payload is completely contained within the range, return
772        // early without any modifications
773        return Some((seq, payload));
774    }
775
776    // the sequence number of the current chunk
777    let mut seq_cursor = seq;
778
779    // we could use `retain` here and remove empty/out-of-bounds chunks, but it's simpler and
780    // probably faster to avoid shifting elements around and just leave empty chunks (and replace
781    // out-of-bounds chunks with empty chunks)
782    for chunk in &mut payload.0 {
783        let original_chunk_len = chunk.len().try_into().unwrap();
784
785        // `take` will replace the current chunk with an empty chunk
786        if let Some((_seq, new_chunk)) = trim_chunk(seq_cursor, std::mem::take(chunk), range) {
787            *chunk = new_chunk;
788        }
789
790        seq_cursor += original_chunk_len;
791    }
792
793    debug_assert_eq!(payload.len(), intersection.len());
794    Some((intersection.start, payload))
795}
796
797/// Trims `chunk`, which starts at a given `seq` number, such that only bytes in the sequence
798/// `range` remain.
799///
800/// If the two ranges do not intersect `None` will be returned. A `None` is also returned if the
801/// range intersects the chunk twice, for example if the chunk covers the range 100..200 and the
802/// given range covers 180..120, but this shouldn't occur for reasonable TCP sequence number ranges.
803/// The returned chunk may be empty if the original `chunk` was empty or the `range` was empty, but
804/// they still intersect according to [`SeqRange::intersection`].
805fn trim_chunk(seq: Seq, mut chunk: Bytes, range: &SeqRange) -> Option<(Seq, Bytes)> {
806    let chunk_range = SeqRange::new(seq, seq + chunk.len().try_into().unwrap());
807
808    let intersection = chunk_range.intersection(range)?;
809
810    let new_offset = intersection.start - seq;
811    let new_len = intersection.len();
812
813    let new_offset: usize = new_offset.try_into().unwrap();
814    let new_len: usize = new_len.try_into().unwrap();
815
816    // update the existing `Bytes` object rather than using `slice()` to avoid an atomic operation
817    chunk.advance(new_offset);
818    chunk.truncate(new_len);
819
820    Some((intersection.start, chunk))
821}
822
#[cfg(test)]
mod tests {
    // Unit tests for the segment/payload/chunk trimming helpers.

    use super::*;

    use std::net::Ipv4Addr;

    // helper to make the tests fit on a single line
    fn range(start: u32, end: u32) -> SeqRange {
        SeqRange::new(Seq::new(start), Seq::new(end))
    }

    // helper to make the tests fit on a single line
    fn seq(val: u32) -> Seq {
        Seq::new(val)
    }

    // helper to make the tests fit on a single line
    fn bytes<const N: usize>(x: &[u8; N]) -> Bytes {
        Bytes::copy_from_slice(x)
    }

    // helper to make the tests fit on a single line; builds a `Payload` with one chunk per
    // byte-string literal (or no chunks at all when given no arguments)
    macro_rules! payload {
        () => {
            Payload([].into())
        };
        ($($slices:literal),+) => {
            {
                let iter = ([$(&$slices[..]),+]).into_iter().map(Bytes::copy_from_slice);
                Payload(iter.collect())
            }
        };
    }

    #[test]
    fn test_trim_segment() {
        fn test_trim(
            flags: TcpFlags,
            seq: Seq,
            payload: impl Into<Payload>,
            range: SeqRange,
        ) -> Option<(TcpFlags, Seq, Bytes)> {
            let header = TcpHeader {
                ip: Ipv4Header {
                    src: Ipv4Addr::UNSPECIFIED,
                    dst: Ipv4Addr::UNSPECIFIED,
                },
                flags,
                src_port: 0,
                dst_port: 0,
                seq: seq.into(),
                ack: 0,
                window_size: 0,
                selective_acks: None,
                window_scale: None,
                timestamp: None,
                timestamp_echo: None,
            };

            let (header, payload) = trim_segment(&header, payload.into(), &range)?;
            let payload = payload.concat();

            Some((header.flags, Seq::new(header.seq), payload))
        }

        const SYN: TcpFlags = TcpFlags::SYN;
        const FIN: TcpFlags = TcpFlags::FIN;
        const ACK: TcpFlags = TcpFlags::ACK;
        const EMPTY: TcpFlags = TcpFlags::empty();

        assert_eq!(test_trim(EMPTY, seq(0), bytes(b""), range(1, 1)), None);
        assert_eq!(test_trim(EMPTY, seq(1), bytes(b""), range(0, 1)), None);
        assert_eq!(
            test_trim(EMPTY, seq(0), bytes(b""), range(0, 0)),
            Some((EMPTY, seq(0), bytes(b""))),
        );
        assert_eq!(
            test_trim(ACK, seq(0), bytes(b""), range(0, 0)),
            Some((ACK, seq(0), bytes(b""))),
        );
        assert_eq!(
            test_trim(EMPTY, seq(0), bytes(b"123"), range(0, 0)),
            Some((EMPTY, seq(0), bytes(b""))),
        );
        assert_eq!(
            test_trim(SYN, seq(0), bytes(b""), range(0, 0)),
            Some((EMPTY, seq(0), bytes(b""))),
        );
        assert_eq!(
            test_trim(FIN, seq(0), bytes(b""), range(0, 0)),
            Some((EMPTY, seq(0), bytes(b""))),
        );
        assert_eq!(
            test_trim(EMPTY, seq(0), bytes(b""), range(0, 2)),
            Some((EMPTY, seq(0), bytes(b""))),
        );
        assert_eq!(
            test_trim(EMPTY, seq(0), bytes(b"123"), range(0, 2)),
            Some((EMPTY, seq(0), bytes(b"12"))),
        );
        assert_eq!(
            test_trim(SYN, seq(0), bytes(b"123"), range(0, 2)),
            Some((SYN, seq(0), bytes(b"1"))),
        );
        assert_eq!(
            test_trim(SYN | FIN, seq(0), bytes(b"123"), range(0, 2)),
            Some((SYN, seq(0), bytes(b"1"))),
        );
        assert_eq!(
            test_trim(SYN | FIN, seq(0), bytes(b"123"), range(1, 2)),
            Some((EMPTY, seq(1), bytes(b"1"))),
        );
        assert_eq!(
            test_trim(SYN | FIN, seq(0), bytes(b"123"), range(1, 5)),
            Some((FIN, seq(1), bytes(b"123"))),
        );
        assert_eq!(
            test_trim(SYN | FIN, seq(0), bytes(b"123"), range(0, 1)),
            Some((SYN, seq(0), bytes(b""))),
        );
        assert_eq!(
            test_trim(SYN | FIN, seq(4), bytes(b"123"), range(0, 5)),
            Some((SYN, seq(4), bytes(b""))),
        );
        assert_eq!(
            test_trim(SYN | FIN | ACK, seq(3), bytes(b"123"), range(0, 5)),
            Some((SYN | ACK, seq(3), bytes(b"1"))),
        );
    }

    #[test]
    fn test_trim_payload() {
        fn test_trim(seq: Seq, payload: Payload, range: SeqRange) -> Option<(Seq, Vec<Bytes>)> {
            let (seq, payload) = trim_payload(seq, payload, &range)?;
            Some((seq, payload.0))
        }

        assert_eq!(
            test_trim(seq(0), payload![b""], range(0, 0)),
            Some((seq(0), payload![b""].0)),
        );
        assert_eq!(
            test_trim(seq(0), payload![], range(0, 0)),
            Some((seq(0), vec![])),
        );
        assert_eq!(test_trim(seq(1), payload![b""], range(0, 0)), None);
        assert_eq!(test_trim(seq(1), payload![b""], range(0, 1)), None);
        assert_eq!(
            test_trim(seq(1), payload![b""], range(0, 2)),
            Some((seq(1), payload![b""].0)),
        );
        assert_eq!(
            test_trim(seq(0), payload![b"a"], range(0, 0)),
            Some((seq(0), payload![b""].0)),
        );
        assert_eq!(
            test_trim(seq(0), payload![b"a"], range(0, 1)),
            Some((seq(0), payload![b"a"].0)),
        );
        assert_eq!(
            test_trim(seq(0), payload![b"ab"], range(0, 1)),
            Some((seq(0), payload![b"a"].0)),
        );
        assert_eq!(
            test_trim(seq(0), payload![b"abcdefg"], range(2, 4)),
            Some((seq(2), payload![b"cd"].0)),
        );
        assert_eq!(
            test_trim(seq(3), payload![b"abcdefg"], range(2, 4)),
            Some((seq(3), payload![b"a"].0)),
        );
        assert_eq!(
            test_trim(seq(3), payload![b"abcdefg"], range(2, 20)),
            Some((seq(3), payload![b"abcdefg"].0)),
        );
        assert_eq!(
            test_trim(seq(3), payload![b"abcdefg"], range(9, 20)),
            Some((seq(9), payload![b"g"].0)),
        );
        assert_eq!(test_trim(seq(3), payload![b"abcdefg"], range(10, 20)), None);

        // second test intersects twice, so returns `None`
        assert_eq!(
            test_trim(seq(5), payload![b"abcdefg"], range(8, 5)),
            Some((seq(8), payload![b"defg"].0)),
        );
        assert_eq!(test_trim(seq(5), payload![b"abcdefg"], range(8, 7)), None);

        // cut off right edge
        assert_eq!(
            test_trim(seq(0), payload![b"a", b"bcd"], range(0, 3)),
            Some((seq(0), payload![b"a", b"bc"].0)),
        );
        // cut off left edge
        assert_eq!(
            test_trim(seq(0), payload![b"abc", b"d"], range(1, 4)),
            Some((seq(1), payload![b"bc", b"d"].0)),
        );
        // cut off left and right edge
        assert_eq!(
            test_trim(seq(0), payload![b"abc", b"def", b"ghi"], range(1, 8)),
            Some((seq(1), payload![b"bc", b"def", b"gh"].0)),
        );
        // cut off left and right chunks
        assert_eq!(
            test_trim(seq(0), payload![b"abc", b"def", b"ghi"], range(3, 6)),
            Some((seq(3), payload![b"", b"def", b""].0)),
        );
        // cut off left and right edges of same chunk
        assert_eq!(
            test_trim(seq(0), payload![b"abc", b"def", b"ghi"], range(4, 5)),
            Some((seq(4), payload![b"", b"e", b""].0)),
        );

        assert_eq!(
            test_trim(
                seq(0),
                payload![b"", b"abc", b"", b"de", b"", b"f", b"", b"ghi"],
                range(3, 6)
            ),
            Some((
                seq(3),
                payload![b"", b"", b"", b"de", b"", b"f", b"", b""].0
            )),
        );

        // range doesn't touch any chunk
        assert_eq!(test_trim(seq(0), payload![b"abc", b"def"], range(10, 20)), None);
        // range covers every chunk, so the payload is returned unchanged
        assert_eq!(
            test_trim(seq(0), payload![b"abc", b"def"], range(0, 6)),
            Some((seq(0), payload![b"abc", b"def"].0)),
        );
        // an empty chunk on the left boundary of the range remains (empty), and the chunk count
        // is preserved
        assert_eq!(
            test_trim(seq(0), payload![b"abc", b"", b"def"], range(3, 6)),
            Some((seq(3), payload![b"", b"", b"def"].0)),
        );
    }

    #[test]
    fn test_trim_chunk() {
        fn test_trim(seq: Seq, chunk: Bytes, range: SeqRange) -> Option<(Seq, Bytes)> {
            trim_chunk(seq, chunk, &range)
        }

        assert_eq!(
            test_trim(seq(0), bytes(b""), range(0, 0)),
            Some((seq(0), bytes(b""))),
        );
        assert_eq!(test_trim(seq(1), bytes(b""), range(0, 0)), None);
        assert_eq!(test_trim(seq(1), bytes(b""), range(0, 1)), None);
        assert_eq!(
            test_trim(seq(1), bytes(b""), range(0, 2)),
            Some((seq(1), bytes(b""))),
        );
        assert_eq!(
            test_trim(seq(0), bytes(b"a"), range(0, 0)),
            Some((seq(0), bytes(b""))),
        );
        assert_eq!(
            test_trim(seq(0), bytes(b"a"), range(0, 1)),
            Some((seq(0), bytes(b"a"))),
        );
        assert_eq!(
            test_trim(seq(0), bytes(b"ab"), range(0, 1)),
            Some((seq(0), bytes(b"a"))),
        );
        assert_eq!(
            test_trim(seq(0), bytes(b"abcdefg"), range(2, 4)),
            Some((seq(2), bytes(b"cd"))),
        );
        assert_eq!(
            test_trim(seq(3), bytes(b"abcdefg"), range(2, 4)),
            Some((seq(3), bytes(b"a"))),
        );
        assert_eq!(
            test_trim(seq(3), bytes(b"abcdefg"), range(2, 20)),
            Some((seq(3), bytes(b"abcdefg"))),
        );
        assert_eq!(
            test_trim(seq(3), bytes(b"abcdefg"), range(9, 20)),
            Some((seq(9), bytes(b"g"))),
        );
        assert_eq!(test_trim(seq(3), bytes(b"abcdefg"), range(10, 20)), None);

        // second test intersects twice, so returns `None`
        assert_eq!(
            test_trim(seq(5), bytes(b"abcdefg"), range(8, 5)),
            Some((seq(8), bytes(b"defg"))),
        );
        assert_eq!(test_trim(seq(5), bytes(b"abcdefg"), range(8, 7)), None);
    }
}