1use std::collections::LinkedList;
2use std::io::{Read, Write};
3
4use bytes::{Buf, Bytes, BytesMut};
5
6use crate::seq::{Seq, SeqRange};
7use crate::util::time::Instant;
8
9#[derive(Debug)]
10pub(crate) struct SendQueue<T: Instant> {
11 segments: LinkedList<Segment>,
12 time_last_segment_sent: Option<T>,
13 transmitted_up_to: Seq,
15 start_seq: Seq,
17 end_seq: Seq,
19 fin_added: bool,
20 unused: BytesMut,
21}
22
23impl<T: Instant> SendQueue<T> {
24 pub fn new(initial_seq: Seq) -> Self {
25 let mut queue = Self {
26 segments: LinkedList::new(),
27 time_last_segment_sent: None,
28 transmitted_up_to: initial_seq,
29 start_seq: initial_seq,
30 end_seq: initial_seq,
31 fin_added: false,
32 unused: BytesMut::new(),
33 };
34
35 queue.add_syn();
36
37 queue
38 }
39
40 fn add_syn(&mut self) {
41 self.add_segment(Segment::Syn);
42 }
43
44 pub fn add_fin(&mut self) {
45 self.add_segment(Segment::Fin);
46 }
47
48 pub fn add_data(
49 &mut self,
50 mut reader: impl Read,
51 mut len: usize,
52 ) -> Result<(), std::io::Error> {
53 const MAX_BYTES_PER_ALLOC: usize = 10_000;
56 const MIN_BYTES_PER_ALLOC: usize = 2000;
57 static_assertions::const_assert!(MIN_BYTES_PER_ALLOC <= MAX_BYTES_PER_ALLOC);
58
59 while len > 0 {
60 if self.unused.is_empty() {
61 let next_alloc_size = len;
67 let next_alloc_size = std::cmp::min(next_alloc_size, MAX_BYTES_PER_ALLOC);
68 let next_alloc_size = std::cmp::max(next_alloc_size, MIN_BYTES_PER_ALLOC);
69 self.unused = BytesMut::zeroed(next_alloc_size);
70 }
71
72 let to_read = std::cmp::min(len, self.unused.len());
74 let mut chunk = self.unused.split_to(to_read);
75
76 reader.read_exact(&mut chunk[..])?;
83 self.add_segment(Segment::Data(chunk.into()));
84
85 len -= to_read;
86 }
87
88 if self.unused.is_empty() {
93 self.unused = BytesMut::new();
94 }
95
96 Ok(())
97 }
98
99 fn add_segment(&mut self, seg: Segment) {
100 assert!(!self.fin_added);
101
102 if matches!(seg, Segment::Fin) {
103 self.fin_added = true;
104 }
105
106 if seg.len() == 0 {
107 return;
108 }
109
110 self.end_seq += seg.len();
111 self.segments.push_back(seg);
112 }
113
114 pub fn start_seq(&self) -> Seq {
115 self.start_seq
116 }
117
118 pub fn next_seq(&self) -> Seq {
119 self.end_seq
120 }
121
122 pub fn contains(&self, seq: Seq) -> bool {
123 SeqRange::new(self.start_seq, self.end_seq).contains(seq)
124 }
125
126 pub fn len(&self) -> u32 {
127 self.end_seq - self.start_seq
128 }
129
130 pub fn advance_start(&mut self, new_start: Seq) {
131 assert!(self.contains(new_start) || new_start == self.end_seq);
132
133 while self.start_seq != new_start {
134 let advance_by = new_start - self.start_seq;
135
136 let front = self.segments.front_mut().unwrap();
138
139 if front.len() <= advance_by {
141 self.start_seq += front.len();
142 self.segments.pop_front();
143 continue;
144 }
145
146 let Segment::Data(data) = front else {
147 unreachable!();
150 };
151
152 data.advance(advance_by.try_into().unwrap());
155 assert!(!data.is_empty());
156
157 self.start_seq += advance_by;
158 }
159 }
160
161 pub fn next_not_transmitted(&self, offset: u32) -> Option<(Seq, Segment)> {
165 let target_seq = self.transmitted_up_to + offset;
167
168 if !self.contains(target_seq) {
170 return None;
171 }
172
173 let mut seq_cursor = self.start_seq;
174 for seg in &self.segments {
175 let len = seg.len();
176
177 if SeqRange::new(seq_cursor, seq_cursor + len).contains(target_seq) {
179 let new_segment = match seg {
180 Segment::Syn => Segment::Syn,
181 Segment::Fin => Segment::Fin,
182 Segment::Data(chunk) => {
183 let chunk_offset = target_seq - seq_cursor;
186 let chunk_offset: usize = chunk_offset.try_into().unwrap();
187 Segment::Data(chunk.slice(chunk_offset..))
188 }
189 };
190
191 return Some((target_seq, new_segment));
192 }
193
194 seq_cursor += len;
195 }
196
197 unreachable!();
200 }
201
202 pub fn mark_as_transmitted(&mut self, up_to: Seq, time: T) {
203 assert!(self.contains(up_to) || up_to == self.end_seq);
204
205 if up_to != self.transmitted_up_to {
206 self.time_last_segment_sent = Some(time);
207 }
208
209 self.transmitted_up_to = up_to;
210 }
211}
212
213#[derive(Debug)]
214pub(crate) struct RecvQueue {
215 segments: LinkedList<Bytes>,
216 start_seq: Seq,
218 end_seq: Seq,
220 syn_added: bool,
221 fin_added: bool,
222}
223
224impl RecvQueue {
225 pub fn new(initial_seq: Seq) -> Self {
226 Self {
227 segments: LinkedList::new(),
228 start_seq: initial_seq,
229 end_seq: initial_seq,
230 syn_added: false,
231 fin_added: false,
232 }
233 }
234
235 pub fn add_syn(&mut self) {
236 assert!(!self.syn_added);
237 self.syn_added = true;
238
239 self.start_seq += 1;
240 self.end_seq += 1;
241 }
242
243 pub fn add_fin(&mut self) {
244 assert!(self.syn_added);
245 assert!(!self.fin_added);
246 self.fin_added = true;
247
248 self.start_seq += 1;
249 self.end_seq += 1;
250 }
251
252 pub fn add(&mut self, data: Bytes) {
253 assert!(self.syn_added);
254 assert!(!self.fin_added);
255
256 let len: u32 = data.len().try_into().unwrap();
257
258 if len == 0 {
259 return;
260 }
261
262 self.end_seq += len;
263 self.segments.push_back(data);
264 }
265
266 pub fn syn_added(&self) -> bool {
267 self.syn_added
268 }
269
270 pub fn len(&self) -> u32 {
271 self.end_seq - self.start_seq
272 }
273
274 pub fn is_empty(&self) -> bool {
275 self.len() == 0
276 }
277
278 pub fn next_seq(&self) -> Seq {
279 self.end_seq
280 }
281
282 pub fn pop(&mut self, len: u32) -> Option<(Seq, Bytes)> {
283 let seq = self.start_seq;
284
285 let chunk_len: u32 = self.segments.front()?.len().try_into().unwrap();
286
287 let segment = if len < chunk_len {
288 self.segments
290 .front_mut()
291 .unwrap()
292 .split_to(len.try_into().unwrap())
293 } else {
294 self.segments.pop_front().unwrap()
297 };
298
299 assert!(!segment.is_empty() || len == 0);
301
302 let advance_by: u32 = segment.len().try_into().unwrap();
303 self.start_seq += advance_by;
304
305 Some((seq, segment))
306 }
307
308 pub fn read(&mut self, mut writer: impl Write, len: usize) -> Result<usize, std::io::Error> {
309 let mut bytes_copied = 0;
310
311 if self.is_empty() {
312 return Ok(0);
313 }
314
315 while bytes_copied < len {
316 let remaining = len - bytes_copied;
317 let remaining_u32 = remaining.try_into().unwrap_or(u32::MAX);
318
319 let Some((_seq, data)) = self.pop(remaining_u32) else {
320 break;
322 };
323
324 assert!(data.len() <= remaining);
325
326 writer.write_all(&data)?;
328
329 bytes_copied += data.len();
330 }
331
332 Ok(bytes_copied)
333 }
334}
335
336#[derive(Debug)]
337pub(crate) enum Segment {
338 Data(Bytes),
339 Syn,
340 Fin,
341}
342
343impl Segment {
344 pub fn len(&self) -> u32 {
345 match self {
346 Segment::Syn | Segment::Fin => 1,
347 Segment::Data(data) => data.len().try_into().unwrap(),
348 }
349 }
350}