tcp/buffer.rs

use std::collections::LinkedList;
use std::io::{Read, Write};

use bytes::{Buf, Bytes, BytesMut};

use crate::seq::{Seq, SeqRange};
use crate::util::time::Instant;

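/// A queue of [`Segment`]s to send to the peer, covering the sequence-number range
/// `[start_seq, end_seq)`. A SYN segment is queued automatically when the queue is created, and
/// the SYN and FIN each occupy one sequence number.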
#[derive(Debug)]
pub(crate) struct SendQueue<T: Instant> {
    segments: LinkedList<Segment>,
    time_last_segment_sent: Option<T>,
    // exclusive
    transmitted_up_to: Seq,
    // inclusive
    start_seq: Seq,
    // exclusive
    end_seq: Seq,
    fin_added: bool,
    unused: BytesMut,
}

impl<T: Instant> SendQueue<T> {
    pub fn new(initial_seq: Seq) -> Self {
        let mut queue = Self {
            segments: LinkedList::new(),
            time_last_segment_sent: None,
            transmitted_up_to: initial_seq,
            start_seq: initial_seq,
            end_seq: initial_seq,
            fin_added: false,
            unused: BytesMut::new(),
        };

        queue.add_syn();

        queue
    }

    fn add_syn(&mut self) {
        self.add_segment(Segment::Syn);
    }

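    /// Queue a FIN segment. No further segments may be added afterwards (enforced by the
    /// assertion in `add_segment`).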
    pub fn add_fin(&mut self) {
        self.add_segment(Segment::Fin);
    }

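    /// Read exactly `len` bytes from `reader` and append them to the queue as one or more data
    /// segments.
    ///
    /// A minimal usage sketch (illustrative only; `queue` is assumed to be an existing
    /// `SendQueue`):
    ///
    /// ```ignore
    /// // `&[u8]` implements `Read`, so a byte slice can be queued directly
    /// let buf: &[u8] = b"hello world";
    /// queue.add_data(buf, buf.len())?;
    /// ```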
    pub fn add_data(
        &mut self,
        mut reader: impl Read,
        mut len: usize,
    ) -> Result<(), std::io::Error> {
        // These values shouldn't affect the TCP behaviour, only how the underlying bytes are
        // allocated. The numbers are chosen arbitrarily.
        const MAX_BYTES_PER_ALLOC: usize = 10_000;
        const MIN_BYTES_PER_ALLOC: usize = 2000;
        static_assertions::const_assert!(MIN_BYTES_PER_ALLOC <= MAX_BYTES_PER_ALLOC);

        while len > 0 {
            if self.unused.is_empty() {
                // Allocate a new buffer with a size equal to the number of bytes to read, clamped
                // to the range `[MIN_BYTES_PER_ALLOC, MAX_BYTES_PER_ALLOC]`. Any allocated bytes of
                // the buffer that aren't used will be re-used the next time that this method is
                // called. This allows us to avoid making many small allocations if the application
                // sends only a small number of bytes at a time.
                let next_alloc_size = len;
                let next_alloc_size = std::cmp::min(next_alloc_size, MAX_BYTES_PER_ALLOC);
                let next_alloc_size = std::cmp::max(next_alloc_size, MIN_BYTES_PER_ALLOC);
                self.unused = BytesMut::zeroed(next_alloc_size);
            }

            // break off a piece of the `unused` buffer
            let to_read = std::cmp::min(len, self.unused.len());
            let mut chunk = self.unused.split_to(to_read);

            // It would be nice if we could merge the segment with the previous data segment (if
            // they are part of the same allocation), but `unsplit` (and `try_unsplit` in our fork)
            // is only available for `BytesMut` and not `Bytes`. If it were available it would allow
            // us to combine several small writes into a larger chunk, which would reduce the number
            // of chunks we need to send in packets.

            reader.read_exact(&mut chunk[..])?;
            self.add_segment(Segment::Data(chunk.into()));

            len -= to_read;
        }

        // If the `unused` buffer is empty, replace it with a new empty `BytesMut`. The old
        // `BytesMut`, while empty, may still point to the old allocation and hold a reference to
        // it, preventing it from being deallocated. We replace it with a new `BytesMut` that does
        // not point to any allocation to make sure that the old allocation can be deallocated.
        if self.unused.is_empty() {
            self.unused = BytesMut::new();
        }

        Ok(())
    }

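    /// Append a segment, extending `end_seq` by the segment's length. Zero-length segments are
    /// dropped, and no segment may be added after a FIN.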
    fn add_segment(&mut self, seg: Segment) {
        assert!(!self.fin_added);

        if matches!(seg, Segment::Fin) {
            self.fin_added = true;
        }

        if seg.len() == 0 {
            return;
        }

        self.end_seq += seg.len();
        self.segments.push_back(seg);
    }

    pub fn start_seq(&self) -> Seq {
        self.start_seq
    }

    pub fn next_seq(&self) -> Seq {
        self.end_seq
    }

    pub fn contains(&self, seq: Seq) -> bool {
        SeqRange::new(self.start_seq, self.end_seq).contains(seq)
    }

    pub fn len(&self) -> u32 {
        self.end_seq - self.start_seq
    }

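    /// Advance the start of the queue to `new_start`, dropping any segments (or leading parts of
    /// segments) that fall before it; typically called once the peer has acknowledged those bytes.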
    pub fn advance_start(&mut self, new_start: Seq) {
        assert!(self.contains(new_start) || new_start == self.end_seq);

        while self.start_seq != new_start {
            let advance_by = new_start - self.start_seq;

            // this shouldn't panic due to the assertion above
            let front = self.segments.front_mut().unwrap();

            // if the chunk would be completely removed
            if front.len() <= advance_by {
                self.start_seq += front.len();
                self.segments.pop_front();
                continue;
            }

            let Segment::Data(data) = front else {
                // syn and fin segments have a length of only 1 byte, so they should have been
                // popped by the check above
                unreachable!();
            };

            // update the existing `Bytes` object rather than using `slice()` to avoid an atomic
            // operation
            data.advance(advance_by.try_into().unwrap());
            assert!(!data.is_empty());

            self.start_seq += advance_by;
        }
    }

    /// Get the next segment that has not yet been transmitted. The `offset` argument shifts the
    /// starting point to `offset` bytes past the first non-transmitted byte.
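    ///
    /// A sketch of how a caller might walk every untransmitted segment (illustrative only;
    /// `queue` is assumed to be an existing `SendQueue`):
    ///
    /// ```ignore
    /// let mut offset = 0;
    /// while let Some((seq, seg)) = queue.next_not_transmitted(offset) {
    ///     // build and transmit a packet carrying (seq, seg) here
    ///     offset += seg.len();
    /// }
    /// ```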
    // TODO: this is slow and is called often
    pub fn next_not_transmitted(&self, offset: u32) -> Option<(Seq, Segment)> {
        // the sequence number of the segment we want to return
        let target_seq = self.transmitted_up_to + offset;

        // check if we've already transmitted everything in the buffer
        if !self.contains(target_seq) {
            return None;
        }

        let mut seq_cursor = self.start_seq;
        for seg in &self.segments {
            let len = seg.len();

            // if the target sequence number is within this segment
            if SeqRange::new(seq_cursor, seq_cursor + len).contains(target_seq) {
                let new_segment = match seg {
                    Segment::Syn => Segment::Syn,
                    Segment::Fin => Segment::Fin,
                    Segment::Data(chunk) => {
                        // the target sequence number might be somewhere within this chunk, so we
                        // need to trim any bytes with a lower sequence number
                        let chunk_offset = target_seq - seq_cursor;
                        let chunk_offset: usize = chunk_offset.try_into().unwrap();
                        Segment::Data(chunk.slice(chunk_offset..))
                    }
                };

                return Some((target_seq, new_segment));
            }

            seq_cursor += len;
        }

        // we confirmed above that the target sequence number is contained within the buffer, but we
        // looped over all segments in the buffer and didn't find it
        unreachable!();
    }

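    /// Record that all bytes before `up_to` have been transmitted. The time of the most recent
    /// transmission is updated whenever this changes the "transmitted up to" point.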
    pub fn mark_as_transmitted(&mut self, up_to: Seq, time: T) {
        assert!(self.contains(up_to) || up_to == self.end_seq);

        if up_to != self.transmitted_up_to {
            self.time_last_segment_sent = Some(time);
        }

        self.transmitted_up_to = up_to;
    }
}

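/// A queue of in-order data received from the peer, waiting to be read by the application. The
/// queue covers the sequence-number range `[start_seq, end_seq)`; the SYN and FIN each consume
/// one sequence number but carry no data.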
#[derive(Debug)]
pub(crate) struct RecvQueue {
    segments: LinkedList<Bytes>,
    // inclusive
    start_seq: Seq,
    // exclusive
    end_seq: Seq,
    syn_added: bool,
    fin_added: bool,
}

impl RecvQueue {
    pub fn new(initial_seq: Seq) -> Self {
        Self {
            segments: LinkedList::new(),
            start_seq: initial_seq,
            end_seq: initial_seq,
            syn_added: false,
            fin_added: false,
        }
    }

    pub fn add_syn(&mut self) {
        assert!(!self.syn_added);
        self.syn_added = true;

        self.start_seq += 1;
        self.end_seq += 1;
    }

    pub fn add_fin(&mut self) {
        assert!(self.syn_added);
        assert!(!self.fin_added);
        self.fin_added = true;

        self.start_seq += 1;
        self.end_seq += 1;
    }

    pub fn add(&mut self, data: Bytes) {
        assert!(self.syn_added);
        assert!(!self.fin_added);

        let len: u32 = data.len().try_into().unwrap();

        if len == 0 {
            return;
        }

        self.end_seq += len;
        self.segments.push_back(data);
    }

    pub fn syn_added(&self) -> bool {
        self.syn_added
    }

    pub fn len(&self) -> u32 {
        self.end_seq - self.start_seq
    }

    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    pub fn next_seq(&self) -> Seq {
        self.end_seq
    }

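    /// Pop up to `len` bytes from the front of the queue as a single chunk, returning the chunk
    /// along with its starting sequence number. At most one underlying chunk is returned, so the
    /// result may be shorter than `len` even if more data is buffered.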
    pub fn pop(&mut self, len: u32) -> Option<(Seq, Bytes)> {
        let seq = self.start_seq;

        let chunk_len: u32 = self.segments.front()?.len().try_into().unwrap();

        let segment = if len < chunk_len {
            // want fewer bytes than the size of the next chunk, so need to split the chunk
            self.segments
                .front_mut()
                .unwrap()
                .split_to(len.try_into().unwrap())
        } else {
            // want at least as many bytes as the next chunk holds, so return the whole chunk
            self.segments.pop_front().unwrap()
        };

        // only return an empty chunk if len was 0
        assert!(!segment.is_empty() || len == 0);

        let advance_by: u32 = segment.len().try_into().unwrap();
        self.start_seq += advance_by;

        Some((seq, segment))
    }

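    /// Copy up to `len` buffered bytes into `writer`, returning the number of bytes copied.
    ///
    /// A minimal usage sketch (illustrative only; `queue` is assumed to be a `RecvQueue` holding
    /// received data):
    ///
    /// ```ignore
    /// // `&mut Vec<u8>` implements `Write`
    /// let mut out = Vec::new();
    /// let n = queue.read(&mut out, 100)?;
    /// assert!(n <= 100);
    /// ```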
    pub fn read(&mut self, mut writer: impl Write, len: usize) -> Result<usize, std::io::Error> {
        let mut bytes_copied = 0;

        if self.is_empty() {
            return Ok(0);
        }

        while bytes_copied < len {
            let remaining = len - bytes_copied;
            let remaining_u32 = remaining.try_into().unwrap_or(u32::MAX);

            let Some((_seq, data)) = self.pop(remaining_u32) else {
                // no more data available
                break;
            };

            assert!(data.len() <= remaining);

            // TODO: the stream will lose partial data if there's an error; is this fine?
            writer.write_all(&data)?;

            bytes_copied += data.len();
        }

        Ok(bytes_copied)
    }
}

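/// A single unit of the send queue. SYN and FIN segments each occupy one sequence number, while a
/// data segment occupies one sequence number per byte of payload.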
#[derive(Debug)]
pub(crate) enum Segment {
    Data(Bytes),
    Syn,
    Fin,
}

impl Segment {
    pub fn len(&self) -> u32 {
        match self {
            Segment::Syn | Segment::Fin => 1,
            Segment::Data(data) => data.len().try_into().unwrap(),
        }
    }
}