1use std::collections::LinkedList;
7use std::io::{ErrorKind, Read, Write};
8
9use bytes::{Bytes, BytesMut};
10
11pub struct ByteQueue {
22    bytes: LinkedList<ByteChunk>,
24    unused_buffer: Option<BytesMut>,
26    length: usize,
28    default_chunk_capacity: usize,
30    #[cfg(test)]
31    total_allocations: u64,
33}
34
35impl ByteQueue {
36    pub fn new(default_chunk_capacity: usize) -> Self {
37        Self {
38            bytes: LinkedList::new(),
39            unused_buffer: None,
40            length: 0,
41            default_chunk_capacity,
42            #[cfg(test)]
43            total_allocations: 0,
44        }
45    }
46
47    pub fn num_bytes(&self) -> usize {
50        self.length
51    }
52
53    pub fn has_bytes(&self) -> bool {
55        self.num_bytes() > 0
56    }
57
58    pub fn has_chunks(&self) -> bool {
60        !self.bytes.is_empty()
61    }
62
63    #[must_use]
64    fn alloc_zeroed_buffer(&mut self, size: usize) -> BytesMut {
65        #[cfg(test)]
66        {
67            self.total_allocations += 1;
68        }
69
70        BytesMut::from_iter(std::iter::repeat_n(0, size))
71    }
72
73    pub fn push_stream<R: Read>(&mut self, mut src: R) -> std::io::Result<usize> {
75        let mut total_copied = 0;
76
77        loop {
78            let mut unused = match self.unused_buffer.take() {
79                Some(x) => x,
81                None => self.alloc_zeroed_buffer(self.default_chunk_capacity),
83            };
84            assert_eq!(unused.len(), unused.capacity());
85
86            let copied = src.read(&mut unused)?;
87            let bytes = unused.split_to(copied);
88
89            total_copied += bytes.len();
90
91            if !unused.is_empty() {
92                self.unused_buffer = Some(unused);
94            }
95
96            if bytes.is_empty() {
97                break;
98            }
99
100            let mut bytes = Some(bytes);
101
102            if let Some(last_chunk) = self.bytes.back_mut() {
104                if last_chunk.chunk_type == ChunkType::Stream {
106                    if let BytesWrapper::Mutable(last_chunk) = &mut last_chunk.data {
108                        let len = bytes.as_ref().unwrap().len();
109                        bytes = last_chunk.try_unsplit(bytes.take().unwrap()).err();
113                        if bytes.is_none() {
114                            self.length += len;
116                        }
117                    }
118                }
119            }
120
121            if let Some(bytes) = bytes {
123                self.push_chunk(bytes, ChunkType::Stream);
124            }
125        }
126
127        Ok(total_copied)
128    }
129
130    pub fn push_packet<R: Read>(&mut self, mut src: R, size: usize) -> std::io::Result<()> {
133        let unused = match &mut self.unused_buffer {
134            Some(buf) if buf.len() >= size => buf,
136            _ => &mut self.alloc_zeroed_buffer(size),
138        };
139        assert_eq!(unused.len(), unused.capacity());
140
141        src.read_exact(&mut unused[..size])?;
142        let bytes = unused.split_to(size);
143
144        if let Some(ref unused_buffer) = self.unused_buffer
146            && unused_buffer.is_empty()
147        {
148            self.unused_buffer = None;
149        }
150
151        self.push_chunk(bytes, ChunkType::Packet);
152
153        Ok(())
154    }
155
156    pub fn push_chunk(&mut self, data: impl Into<BytesWrapper>, chunk_type: ChunkType) -> usize {
158        let data = data.into();
159        let len = data.len();
160        self.length += len;
161        self.bytes.push_back(ByteChunk::new(data, chunk_type));
162        len
163    }
164
165    pub fn pop<W: Write>(&mut self, dst: W) -> std::io::Result<Option<(usize, usize, ChunkType)>> {
172        match self.bytes.front() {
174            Some(x) => match x.chunk_type {
175                ChunkType::Stream => {
176                    let num_copied = self.pop_stream(dst)?;
177                    Ok(Some((num_copied, num_copied, ChunkType::Stream)))
178                }
179                ChunkType::Packet => {
180                    let (num_copied, num_removed_from_buf) = self.pop_packet(dst)?;
181                    Ok(Some((num_copied, num_removed_from_buf, ChunkType::Packet)))
182                }
183            },
184            None => Ok(None),
185        }
186    }
187
188    fn pop_stream<W: Write>(&mut self, mut dst: W) -> std::io::Result<usize> {
189        let mut total_copied = 0;
190        assert_ne!(
191            self.bytes.len(),
192            0,
193            "This function assumes there is a chunk"
194        );
195
196        loop {
197            let bytes = match self.bytes.front_mut() {
198                Some(x) if x.chunk_type != ChunkType::Stream => break,
199                Some(x) => &mut x.data,
200                None => break,
201            };
202
203            let copied = match dst.write(bytes.as_ref()) {
204                Ok(x) => x,
205                Err(e) if e.kind() == ErrorKind::Interrupted => continue,
207                Err(e) if e.kind() == ErrorKind::WouldBlock => {
208                    if total_copied == 0 {
210                        return Err(e);
211                    }
212                    0
214                }
215                Err(e) => return Err(e),
217            };
218
219            let _ = bytes.split_to(copied);
220
221            if copied == 0 {
222                break;
223            }
224
225            self.length -= copied;
226            total_copied += copied;
227
228            if bytes.is_empty() {
229                self.bytes.pop_front();
230            }
231        }
232
233        Ok(total_copied)
234    }
235
236    fn pop_packet<W: Write>(&mut self, mut dst: W) -> std::io::Result<(usize, usize)> {
237        let mut chunk = self
238            .bytes
239            .pop_front()
240            .expect("This function assumes there is a chunk");
241        assert_eq!(chunk.chunk_type, ChunkType::Packet);
242        let bytes = &mut chunk.data;
243
244        let packet_len = bytes.len();
245
246        self.length = self.length.checked_sub(packet_len).unwrap();
248
249        let mut total_copied = 0;
250
251        loop {
252            let copied = match dst.write(bytes.as_ref()) {
253                Ok(x) => x,
254                Err(e) if e.kind() == ErrorKind::Interrupted => continue,
256                Err(e) if e.kind() == ErrorKind::WouldBlock => {
259                    panic!("Non-blocking writers aren't supported for packets")
260                }
261                Err(e) => return Err(e),
264            };
265
266            let _ = bytes.split_to(copied);
267
268            if copied == 0 {
269                break;
270            }
271
272            total_copied += copied;
273        }
274
275        Ok((total_copied, packet_len))
276    }
277
278    pub fn pop_chunk(&mut self, size_hint: usize) -> Option<(Bytes, ChunkType)> {
282        let chunk = self.bytes.front_mut()?;
283        let chunk_type = chunk.chunk_type;
284
285        let bytes = match chunk_type {
286            ChunkType::Stream => {
287                let temp = chunk
288                    .data
289                    .split_to(std::cmp::min(chunk.data.len(), size_hint));
290                if chunk.data.is_empty() {
291                    self.bytes.pop_front();
292                }
293                temp
294            }
295            ChunkType::Packet => self.bytes.pop_front().unwrap().data,
296        };
297
298        self.length -= bytes.len();
299
300        Some((bytes.into(), chunk_type))
301    }
302
303    pub fn peek<W: Write>(&self, dst: W) -> std::io::Result<Option<(usize, usize, ChunkType)>> {
309        match self.bytes.front() {
311            Some(x) => match x.chunk_type {
312                ChunkType::Stream => {
313                    let num_copied = self.peek_stream(dst)?;
314                    Ok(Some((num_copied, num_copied, ChunkType::Stream)))
315                }
316                ChunkType::Packet => {
317                    let (num_copied, size_of_packet) = self.peek_packet(dst)?;
318                    Ok(Some((num_copied, size_of_packet, ChunkType::Packet)))
319                }
320            },
321            None => Ok(None),
322        }
323    }
324
325    fn peek_stream<W: Write>(&self, mut dst: W) -> std::io::Result<usize> {
326        let mut total_copied = 0;
327        assert_ne!(
328            self.bytes.len(),
329            0,
330            "This function assumes there is a chunk"
331        );
332
333        for bytes in self.bytes.iter() {
334            let mut bytes = match bytes {
335                x if x.chunk_type != ChunkType::Stream => break,
336                x => x.data.as_ref(),
337            };
338
339            loop {
340                let copied = match dst.write(bytes) {
341                    Ok(x) => x,
342                    Err(e) if e.kind() == ErrorKind::Interrupted => continue,
344                    Err(e) if e.kind() == ErrorKind::WouldBlock => {
345                        if total_copied == 0 {
347                            return Err(e);
348                        }
349                        0
351                    }
352                    Err(e) => return Err(e),
354                };
355
356                bytes = &bytes[copied..];
357
358                if copied == 0 {
359                    break;
360                }
361
362                total_copied += copied;
363            }
364        }
365
366        Ok(total_copied)
367    }
368
369    fn peek_packet<W: Write>(&self, mut dst: W) -> std::io::Result<(usize, usize)> {
370        let chunk = self
371            .bytes
372            .front()
373            .expect("This function assumes there is a chunk");
374
375        assert_eq!(chunk.chunk_type, ChunkType::Packet);
376        let mut bytes = chunk.data.as_ref();
377        let packet_len = bytes.len();
378        let mut total_copied = 0;
379
380        loop {
381            let copied = match dst.write(bytes) {
382                Ok(x) => x,
383                Err(e) if e.kind() == ErrorKind::Interrupted => continue,
385                Err(e) if e.kind() == ErrorKind::WouldBlock => {
388                    panic!("Non-blocking writers aren't supported for packets")
389                }
390                Err(e) => return Err(e),
393            };
394
395            bytes = &bytes[copied..];
396
397            if copied == 0 {
398                break;
399            }
400
401            total_copied += copied;
402        }
403
404        Ok((total_copied, packet_len))
405    }
406}
407
408#[cfg(debug_assertions)]
410impl std::ops::Drop for ByteQueue {
411    fn drop(&mut self) {
412        assert_eq!(
414            self.num_bytes(),
415            self.bytes.iter().map(|x| x.data.len()).sum::<usize>()
416        );
417    }
418}
419
/// The kind of data held by a chunk.
#[derive(Clone, Copy)]
#[derive(Debug, Eq, PartialEq)]
pub enum ChunkType {
    /// Stream data: consecutive stream chunks can be read out as one run of bytes.
    Stream,
    /// A single packet, read out as a unit.
    Packet,
}
426
427pub enum BytesWrapper {
429    Mutable(BytesMut),
430    Immutable(Bytes),
431}
432
433impl From<BytesMut> for BytesWrapper {
434    fn from(x: BytesMut) -> Self {
435        BytesWrapper::Mutable(x)
436    }
437}
438
439impl From<Bytes> for BytesWrapper {
440    fn from(x: Bytes) -> Self {
441        BytesWrapper::Immutable(x)
442    }
443}
444
445impl From<BytesWrapper> for Bytes {
446    fn from(x: BytesWrapper) -> Self {
447        match x {
448            BytesWrapper::Mutable(x) => x.freeze(),
449            BytesWrapper::Immutable(x) => x,
450        }
451    }
452}
453
454impl std::convert::AsRef<[u8]> for BytesWrapper {
455    fn as_ref(&self) -> &[u8] {
456        match self {
457            BytesWrapper::Mutable(x) => x,
458            BytesWrapper::Immutable(x) => x,
459        }
460    }
461}
462
463impl std::borrow::Borrow<[u8]> for BytesWrapper {
464    fn borrow(&self) -> &[u8] {
465        self.as_ref()
466    }
467}
468
impl BytesWrapper {
    // The methods below are generated by macros defined outside this file. Presumably each
    // one forwards to the same-named method of whichever buffer (`Mutable` / `Immutable`) is
    // wrapped — NOTE(review): confirm against the macro definitions.

    // Number of bytes in the wrapped buffer.
    enum_passthrough!(self, (), Mutable, Immutable;
        pub fn len(&self) -> usize
    );
    // Whether the wrapped buffer holds no bytes.
    enum_passthrough!(self, (), Mutable, Immutable;
        pub fn is_empty(&self) -> bool
    );
    // Splits off and returns the first `at` bytes, leaving the remainder in `self`. The
    // `_into` variant presumably converts the inner result back into a `BytesWrapper` via the
    // `From` impls — TODO confirm.
    enum_passthrough_into!(self, (at), Mutable, Immutable;
        pub fn split_to(&mut self, at: usize) -> BytesWrapper
    );
}
480
481struct ByteChunk {
483    data: BytesWrapper,
484    chunk_type: ChunkType,
485}
486
487impl ByteChunk {
488    pub fn new(data: BytesWrapper, chunk_type: ChunkType) -> Self {
489        Self { data, chunk_type }
490    }
491}
492
#[cfg(test)]
mod tests {
    use rand::{Rng, RngCore};
    use rand_chacha::ChaCha20Rng;
    use rand_core::SeedableRng;

    use super::*;

    /// Pushes stream data in several calls (including an empty one), then checks peek output,
    /// pop output, byte counts, and chunk/allocation counts.
    #[test]
    fn test_bytequeue_stream() {
        let chunk_size = 5;
        let mut bq = ByteQueue::new(chunk_size);

        let src1 = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13];
        let src2 = [51, 52, 53];
        let mut dst1 = [0; 8];
        let mut dst2 = [0; 10];

        bq.push_stream(&src1[..]).unwrap();
        bq.push_stream(&[][..]).unwrap();
        bq.push_stream(&src2[..]).unwrap();

        assert_eq!(bq.num_bytes(), src1.len() + src2.len());
        // expects ceil(total_bytes / chunk_size) chunks: leftover buffer space is reused and
        // contiguous stream data is merged into the previous chunk
        assert_eq!(
            bq.bytes.len(),
            (src1.len() + src2.len() - 1) / chunk_size + 1
        );
        // one buffer allocation per chunk
        assert_eq!(bq.total_allocations as usize, bq.bytes.len());

        // peeking twice reads from the start both times (peek doesn't consume)
        assert_eq!(8, bq.peek(&mut dst1[..]).unwrap().unwrap().0);
        assert_eq!(10, bq.peek(&mut dst2[..]).unwrap().unwrap().0);

        assert_eq!(dst1, [1, 2, 3, 4, 5, 6, 7, 8]);
        assert_eq!(dst2, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
        assert_eq!(bq.num_bytes(), src1.len() + src2.len());

        dst1.fill(0);
        dst2.fill(0);

        // the second pop continues where the first stopped and drains the remaining 8 bytes
        assert_eq!(8, bq.pop(&mut dst1[..]).unwrap().unwrap().0);
        assert_eq!(8, bq.pop(&mut dst2[..]).unwrap().unwrap().0);

        assert_eq!(dst1, [1, 2, 3, 4, 5, 6, 7, 8]);
        assert_eq!(dst2, [9, 10, 11, 12, 13, 51, 52, 53, 0, 0]);
        assert_eq!(bq.num_bytes(), 0);
    }

    /// Pushes three packets (one empty) and checks that peek/pop treat each packet as a unit.
    #[test]
    fn test_bytequeue_packet() {
        let mut bq = ByteQueue::new(5);

        let src1 = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13];
        let src2 = [51, 52, 53];
        let mut dst1 = [0; 8];
        let mut dst2 = [0; 10];

        bq.push_packet(&src1[..], src1.len()).unwrap();
        bq.push_packet(&[][..], 0).unwrap();
        bq.push_packet(&src2[..], src2.len()).unwrap();

        assert_eq!(bq.num_bytes(), src1.len() + src2.len());
        // each packet is its own chunk, and each needed its own allocation (including the
        // zero-length one, since there was no leftover buffer to reuse)
        assert_eq!(bq.bytes.len(), 3);
        assert_eq!(bq.total_allocations, 3);

        // peeking only ever reads the first packet, truncated to the destination size
        assert_eq!(8, bq.peek(&mut dst1[..]).unwrap().unwrap().0);
        assert_eq!(10, bq.peek(&mut dst2[..]).unwrap().unwrap().0);
        assert_eq!(10, bq.peek(&mut dst2[..]).unwrap().unwrap().0);

        assert_eq!(dst1, [1, 2, 3, 4, 5, 6, 7, 8]);
        assert_eq!(dst2, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
        assert_eq!(bq.num_bytes(), src1.len() + src2.len());

        dst1.fill(0);
        dst2.fill(0);

        // popping a packet removes all of it even though only 8 of its 13 bytes fit in dst1;
        // then the empty packet pops as 0 bytes, then the last packet
        assert_eq!(8, bq.pop(&mut dst1[..]).unwrap().unwrap().0);
        assert_eq!(0, bq.pop(&mut dst2[..]).unwrap().unwrap().0);
        assert_eq!(3, bq.pop(&mut dst2[..]).unwrap().unwrap().0);

        assert_eq!(dst1, [1, 2, 3, 4, 5, 6, 7, 8]);
        assert_eq!(dst2, [51, 52, 53, 0, 0, 0, 0, 0, 0, 0]);
        assert_eq!(bq.num_bytes(), 0);
    }

    /// A packet between two stream pushes: the stream data must not merge across the packet.
    #[test]
    fn test_bytequeue_combined_1() {
        let mut bq = ByteQueue::new(10);

        bq.push_stream(&[1, 2, 3][..]).unwrap();
        bq.push_packet(&[4, 5, 6][..], 3).unwrap();
        bq.push_stream(&[7, 8, 9][..]).unwrap();

        assert_eq!(bq.num_bytes(), 9);
        assert_eq!(bq.bytes.len(), 3);
        // all three pushes fit into the first 10-byte buffer
        assert_eq!(bq.total_allocations, 1);

        let mut buf = [0; 20];

        assert_eq!(
            bq.pop(&mut buf[..]).unwrap(),
            Some((3, 3, ChunkType::Stream))
        );
        assert_eq!(buf[..3], [1, 2, 3]);

        assert_eq!(
            bq.pop(&mut buf[..]).unwrap(),
            Some((3, 3, ChunkType::Packet))
        );
        assert_eq!(buf[..3], [4, 5, 6]);

        assert_eq!(
            bq.pop(&mut buf[..]).unwrap(),
            Some((3, 3, ChunkType::Stream))
        );
        assert_eq!(buf[..3], [7, 8, 9]);

        assert!(!bq.has_bytes());
    }

    /// Mixed stream pushes, packet pushes, an externally provided `Bytes` packet and an empty
    /// packet, drained through a mix of `pop` (with small destinations) and `pop_chunk`.
    #[test]
    fn test_bytequeue_combined_2() {
        let mut bq = ByteQueue::new(5);

        bq.push_stream(&[1, 2, 3, 4][..]).unwrap();
        bq.push_stream(&[5][..]).unwrap();
        bq.push_stream(&[6][..]).unwrap();
        bq.push_packet(&[7, 8, 9, 10, 11, 12, 13, 14][..], 8)
            .unwrap();
        bq.push_stream(&[15, 16, 17][..]).unwrap();
        bq.push_chunk(
            Bytes::from_static(&[100, 101, 102, 103, 104, 105]),
            ChunkType::Packet,
        );
        bq.push_packet(&[][..], 0).unwrap();
        bq.push_stream(&[18][..]).unwrap();
        bq.push_stream(&[19][..]).unwrap();
        bq.push_stream(&[20, 21][..]).unwrap();

        let mut buf = [0; 20];

        assert_eq!(
            bq.pop(&mut buf[..3]).unwrap(),
            Some((3, 3, ChunkType::Stream))
        );
        assert_eq!(buf[..3], [1, 2, 3]);

        // stream pops stop at the packet boundary even though the destination has room
        assert_eq!(
            bq.pop(&mut buf[..5]).unwrap(),
            Some((3, 3, ChunkType::Stream))
        );
        assert_eq!(buf[..3], [4, 5, 6]);

        // a truncated packet pop still removes the whole 8-byte packet
        assert_eq!(
            bq.pop(&mut buf[..4]).unwrap(),
            Some((4, 8, ChunkType::Packet))
        );
        assert_eq!(buf[..4], [7, 8, 9, 10]);

        assert_eq!(
            bq.pop(&mut buf[..4]).unwrap(),
            Some((3, 3, ChunkType::Stream))
        );
        assert_eq!(buf[..3], [15, 16, 17]);

        assert_eq!(
            bq.pop(&mut buf[..4]).unwrap(),
            Some((4, 6, ChunkType::Packet))
        );
        assert_eq!(buf[..4], [100, 101, 102, 103]);

        assert_eq!(
            bq.pop(&mut buf[..4]).unwrap(),
            Some((0, 0, ChunkType::Packet))
        );

        // `pop_chunk` returns one chunk at a time: [18] is separate from [19, 20, 21]
        assert_eq!(bq.pop_chunk(4), Some(([18][..].into(), ChunkType::Stream)));

        assert_eq!(
            bq.pop_chunk(4),
            Some(([19, 20, 21][..].into(), ChunkType::Stream))
        );

        assert_eq!(bq.pop_chunk(8), None);
        assert_eq!(bq.pop(&mut buf[..4]).unwrap(), None);
        assert!(!bq.has_bytes());
    }

    /// A writer that always fails: both pops return errors; the packet is removed (it's
    /// dropped before the write is attempted) while the stream data stays queued.
    #[test]
    fn test_bytequeue_fallible_writer() {
        struct TestWriter;

        impl std::io::Write for TestWriter {
            fn write(&mut self, _buf: &[u8]) -> std::io::Result<usize> {
                Err(std::io::ErrorKind::BrokenPipe.into())
            }
            fn flush(&mut self) -> std::io::Result<()> {
                Ok(())
            }
        }

        let mut bq = ByteQueue::new(10);

        bq.push_packet(&[4, 5, 6][..], 3).unwrap();
        bq.push_stream(&[1, 2, 3][..]).unwrap();

        let mut writer = TestWriter {};

        // first pop: the packet
        bq.pop(&mut writer).unwrap_err();
        // second pop: the stream chunk
        bq.pop(&mut writer).unwrap_err();

        // only the 3 stream bytes remain
        assert_eq!(bq.num_bytes(), 3);
    }

    /// Randomized check that `pop` writes and returns exactly what the preceding `peek` did,
    /// over a seeded random interleaving of stream/packet pushes.
    #[test]
    fn test_bytequeue_peek() {
        let mut rng = ChaCha20Rng::seed_from_u64(1234);

        const PROB_PUSH: f64 = 0.8;
        const PROB_POP: f64 = 0.9;
        const PROB_STREAM: f64 = 0.5;
        const MAX_PUSH: usize = 20;
        const MAX_POP: usize = 30;

        #[cfg(not(miri))]
        const NUM_ITER: usize = 5000;
        #[cfg(miri)]
        const NUM_ITER: usize = 10;

        // pop more often and in larger amounts than push, so the queue doesn't grow unbounded
        static_assertions::const_assert!(PROB_POP > PROB_PUSH);
        static_assertions::const_assert!(MAX_POP > MAX_PUSH);

        let mut bq = ByteQueue::new(10);

        for _ in 0..NUM_ITER {
            if rng.random_bool(PROB_PUSH) {
                // push a random number of random bytes, as either stream data or a packet
                let mut bytes = vec![0u8; rng.random_range(0..MAX_PUSH)];
                rng.fill_bytes(&mut bytes);

                if rng.random_bool(PROB_STREAM) {
                    bq.push_stream(&bytes[..]).unwrap();
                } else {
                    bq.push_packet(&bytes[..], bytes.len()).unwrap();
                }
            }

            let pop_size = rng.random_range(0..MAX_POP);

            // peek into a buffer of the same size that a following pop would use
            let mut peeked_bytes = vec![0u8; pop_size];
            let peek_rv = bq.peek(&mut peeked_bytes[..]).unwrap();

            if rng.random_bool(PROB_POP) {
                // the pop must match the peek exactly, both in return value and in output
                let mut popped_bytes = vec![0u8; pop_size];
                let pop_rv = bq.pop(&mut popped_bytes[..]).unwrap();

                assert_eq!(peek_rv, pop_rv);
                assert_eq!(popped_bytes, peeked_bytes);
            }
        }
    }
}