1use std::collections::LinkedList;
7use std::io::{ErrorKind, Read, Write};
8
9use bytes::{Bytes, BytesMut};
10
11pub struct ByteQueue {
22 bytes: LinkedList<ByteChunk>,
24 unused_buffer: Option<BytesMut>,
26 length: usize,
28 default_chunk_capacity: usize,
30 #[cfg(test)]
31 total_allocations: u64,
33}
34
35impl ByteQueue {
36 pub fn new(default_chunk_capacity: usize) -> Self {
37 Self {
38 bytes: LinkedList::new(),
39 unused_buffer: None,
40 length: 0,
41 default_chunk_capacity,
42 #[cfg(test)]
43 total_allocations: 0,
44 }
45 }
46
47 pub fn num_bytes(&self) -> usize {
50 self.length
51 }
52
53 pub fn has_bytes(&self) -> bool {
55 self.num_bytes() > 0
56 }
57
58 pub fn has_chunks(&self) -> bool {
60 !self.bytes.is_empty()
61 }
62
63 #[must_use]
64 fn alloc_zeroed_buffer(&mut self, size: usize) -> BytesMut {
65 #[cfg(test)]
66 {
67 self.total_allocations += 1;
68 }
69
70 BytesMut::from_iter(std::iter::repeat_n(0, size))
71 }
72
73 pub fn push_stream<R: Read>(&mut self, mut src: R) -> std::io::Result<usize> {
75 let mut total_copied = 0;
76
77 loop {
78 let mut unused = match self.unused_buffer.take() {
79 Some(x) => x,
81 None => self.alloc_zeroed_buffer(self.default_chunk_capacity),
83 };
84 assert_eq!(unused.len(), unused.capacity());
85
86 let copied = src.read(&mut unused)?;
87 let bytes = unused.split_to(copied);
88
89 total_copied += bytes.len();
90
91 if !unused.is_empty() {
92 self.unused_buffer = Some(unused);
94 }
95
96 if bytes.is_empty() {
97 break;
98 }
99
100 let mut bytes = Some(bytes);
101
102 if let Some(last_chunk) = self.bytes.back_mut() {
104 if last_chunk.chunk_type == ChunkType::Stream {
106 if let BytesWrapper::Mutable(last_chunk) = &mut last_chunk.data {
108 let len = bytes.as_ref().unwrap().len();
109 bytes = last_chunk.try_unsplit(bytes.take().unwrap()).err();
113 if bytes.is_none() {
114 self.length += len;
116 }
117 }
118 }
119 }
120
121 if let Some(bytes) = bytes {
123 self.push_chunk(bytes, ChunkType::Stream);
124 }
125 }
126
127 Ok(total_copied)
128 }
129
130 pub fn push_packet<R: Read>(&mut self, mut src: R, size: usize) -> std::io::Result<()> {
133 let unused = match &mut self.unused_buffer {
134 Some(buf) if buf.len() >= size => buf,
136 _ => &mut self.alloc_zeroed_buffer(size),
138 };
139 assert_eq!(unused.len(), unused.capacity());
140
141 src.read_exact(&mut unused[..size])?;
142 let bytes = unused.split_to(size);
143
144 if let Some(ref unused_buffer) = self.unused_buffer {
146 if unused_buffer.is_empty() {
147 self.unused_buffer = None;
148 }
149 }
150
151 self.push_chunk(bytes, ChunkType::Packet);
152
153 Ok(())
154 }
155
156 pub fn push_chunk(&mut self, data: impl Into<BytesWrapper>, chunk_type: ChunkType) -> usize {
158 let data = data.into();
159 let len = data.len();
160 self.length += len;
161 self.bytes.push_back(ByteChunk::new(data, chunk_type));
162 len
163 }
164
165 pub fn pop<W: Write>(&mut self, dst: W) -> std::io::Result<Option<(usize, usize, ChunkType)>> {
172 match self.bytes.front() {
174 Some(x) => match x.chunk_type {
175 ChunkType::Stream => {
176 let num_copied = self.pop_stream(dst)?;
177 Ok(Some((num_copied, num_copied, ChunkType::Stream)))
178 }
179 ChunkType::Packet => {
180 let (num_copied, num_removed_from_buf) = self.pop_packet(dst)?;
181 Ok(Some((num_copied, num_removed_from_buf, ChunkType::Packet)))
182 }
183 },
184 None => Ok(None),
185 }
186 }
187
188 fn pop_stream<W: Write>(&mut self, mut dst: W) -> std::io::Result<usize> {
189 let mut total_copied = 0;
190 assert_ne!(
191 self.bytes.len(),
192 0,
193 "This function assumes there is a chunk"
194 );
195
196 loop {
197 let bytes = match self.bytes.front_mut() {
198 Some(x) if x.chunk_type != ChunkType::Stream => break,
199 Some(x) => &mut x.data,
200 None => break,
201 };
202
203 let copied = match dst.write(bytes.as_ref()) {
204 Ok(x) => x,
205 Err(e) if e.kind() == ErrorKind::Interrupted => continue,
207 Err(e) if e.kind() == ErrorKind::WouldBlock => {
208 if total_copied == 0 {
210 return Err(e);
211 }
212 0
214 }
215 Err(e) => return Err(e),
217 };
218
219 let _ = bytes.split_to(copied);
220
221 if copied == 0 {
222 break;
223 }
224
225 self.length -= copied;
226 total_copied += copied;
227
228 if bytes.is_empty() {
229 self.bytes.pop_front();
230 }
231 }
232
233 Ok(total_copied)
234 }
235
236 fn pop_packet<W: Write>(&mut self, mut dst: W) -> std::io::Result<(usize, usize)> {
237 let mut chunk = self
238 .bytes
239 .pop_front()
240 .expect("This function assumes there is a chunk");
241 assert_eq!(chunk.chunk_type, ChunkType::Packet);
242 let bytes = &mut chunk.data;
243
244 let packet_len = bytes.len();
245
246 self.length = self.length.checked_sub(packet_len).unwrap();
248
249 let mut total_copied = 0;
250
251 loop {
252 let copied = match dst.write(bytes.as_ref()) {
253 Ok(x) => x,
254 Err(e) if e.kind() == ErrorKind::Interrupted => continue,
256 Err(e) if e.kind() == ErrorKind::WouldBlock => {
259 panic!("Non-blocking writers aren't supported for packets")
260 }
261 Err(e) => return Err(e),
264 };
265
266 let _ = bytes.split_to(copied);
267
268 if copied == 0 {
269 break;
270 }
271
272 total_copied += copied;
273 }
274
275 Ok((total_copied, packet_len))
276 }
277
278 pub fn pop_chunk(&mut self, size_hint: usize) -> Option<(Bytes, ChunkType)> {
282 let chunk = self.bytes.front_mut()?;
283 let chunk_type = chunk.chunk_type;
284
285 let bytes = match chunk_type {
286 ChunkType::Stream => {
287 let temp = chunk
288 .data
289 .split_to(std::cmp::min(chunk.data.len(), size_hint));
290 if chunk.data.is_empty() {
291 self.bytes.pop_front();
292 }
293 temp
294 }
295 ChunkType::Packet => self.bytes.pop_front().unwrap().data,
296 };
297
298 self.length -= bytes.len();
299
300 Some((bytes.into(), chunk_type))
301 }
302
303 pub fn peek<W: Write>(&self, dst: W) -> std::io::Result<Option<(usize, usize, ChunkType)>> {
309 match self.bytes.front() {
311 Some(x) => match x.chunk_type {
312 ChunkType::Stream => {
313 let num_copied = self.peek_stream(dst)?;
314 Ok(Some((num_copied, num_copied, ChunkType::Stream)))
315 }
316 ChunkType::Packet => {
317 let (num_copied, size_of_packet) = self.peek_packet(dst)?;
318 Ok(Some((num_copied, size_of_packet, ChunkType::Packet)))
319 }
320 },
321 None => Ok(None),
322 }
323 }
324
325 fn peek_stream<W: Write>(&self, mut dst: W) -> std::io::Result<usize> {
326 let mut total_copied = 0;
327 assert_ne!(
328 self.bytes.len(),
329 0,
330 "This function assumes there is a chunk"
331 );
332
333 for bytes in self.bytes.iter() {
334 let mut bytes = match bytes {
335 x if x.chunk_type != ChunkType::Stream => break,
336 x => x.data.as_ref(),
337 };
338
339 loop {
340 let copied = match dst.write(bytes) {
341 Ok(x) => x,
342 Err(e) if e.kind() == ErrorKind::Interrupted => continue,
344 Err(e) if e.kind() == ErrorKind::WouldBlock => {
345 if total_copied == 0 {
347 return Err(e);
348 }
349 0
351 }
352 Err(e) => return Err(e),
354 };
355
356 bytes = &bytes[copied..];
357
358 if copied == 0 {
359 break;
360 }
361
362 total_copied += copied;
363 }
364 }
365
366 Ok(total_copied)
367 }
368
369 fn peek_packet<W: Write>(&self, mut dst: W) -> std::io::Result<(usize, usize)> {
370 let chunk = self
371 .bytes
372 .front()
373 .expect("This function assumes there is a chunk");
374
375 assert_eq!(chunk.chunk_type, ChunkType::Packet);
376 let mut bytes = chunk.data.as_ref();
377 let packet_len = bytes.len();
378 let mut total_copied = 0;
379
380 loop {
381 let copied = match dst.write(bytes) {
382 Ok(x) => x,
383 Err(e) if e.kind() == ErrorKind::Interrupted => continue,
385 Err(e) if e.kind() == ErrorKind::WouldBlock => {
388 panic!("Non-blocking writers aren't supported for packets")
389 }
390 Err(e) => return Err(e),
393 };
394
395 bytes = &bytes[copied..];
396
397 if copied == 0 {
398 break;
399 }
400
401 total_copied += copied;
402 }
403
404 Ok((total_copied, packet_len))
405 }
406}
407
408#[cfg(debug_assertions)]
410impl std::ops::Drop for ByteQueue {
411 fn drop(&mut self) {
412 assert_eq!(
414 self.num_bytes(),
415 self.bytes.iter().map(|x| x.data.len()).sum::<usize>()
416 );
417 }
418}
419
/// The kind of data held by a chunk in a [`ByteQueue`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChunkType {
    /// Stream data: consecutive stream chunks are popped and peeked as one sequence of bytes.
    Stream,
    /// A packet: its bytes are always removed from the queue together.
    Packet,
}
426
427pub enum BytesWrapper {
429 Mutable(BytesMut),
430 Immutable(Bytes),
431}
432
433impl From<BytesMut> for BytesWrapper {
434 fn from(x: BytesMut) -> Self {
435 BytesWrapper::Mutable(x)
436 }
437}
438
439impl From<Bytes> for BytesWrapper {
440 fn from(x: Bytes) -> Self {
441 BytesWrapper::Immutable(x)
442 }
443}
444
445impl From<BytesWrapper> for Bytes {
446 fn from(x: BytesWrapper) -> Self {
447 match x {
448 BytesWrapper::Mutable(x) => x.freeze(),
449 BytesWrapper::Immutable(x) => x,
450 }
451 }
452}
453
454impl std::convert::AsRef<[u8]> for BytesWrapper {
455 fn as_ref(&self) -> &[u8] {
456 match self {
457 BytesWrapper::Mutable(x) => x,
458 BytesWrapper::Immutable(x) => x,
459 }
460 }
461}
462
463impl std::borrow::Borrow<[u8]> for BytesWrapper {
464 fn borrow(&self) -> &[u8] {
465 self.as_ref()
466 }
467}
468
// Slice-like accessors for `BytesWrapper`. Each method is generated by a macro defined outside
// this file. Presumably the macro matches on the `Mutable`/`Immutable` variant and forwards the
// call (with the listed arguments) to the wrapped buffer, and `enum_passthrough_into!`
// additionally converts the result `.into()` a `BytesWrapper`. TODO: confirm against the
// macro definitions.
impl BytesWrapper {
    // Number of bytes held.
    enum_passthrough!(self, (), Mutable, Immutable;
        pub fn len(&self) -> usize
    );
    // Whether no bytes are held.
    enum_passthrough!(self, (), Mutable, Immutable;
        pub fn is_empty(&self) -> bool
    );
    // Split off and return the first `at` bytes, leaving the rest in `self`. This is assumed to
    // mirror `BytesMut::split_to` / `Bytes::split_to`; see the macro definition.
    enum_passthrough_into!(self, (at), Mutable, Immutable;
        pub fn split_to(&mut self, at: usize) -> BytesWrapper
    );
}
480
481struct ByteChunk {
483 data: BytesWrapper,
484 chunk_type: ChunkType,
485}
486
487impl ByteChunk {
488 pub fn new(data: BytesWrapper, chunk_type: ChunkType) -> Self {
489 Self { data, chunk_type }
490 }
491}
492
#[cfg(test)]
mod tests {
    use rand::{Rng, RngCore};
    use rand_chacha::ChaCha20Rng;
    use rand_core::SeedableRng;

    use super::*;

    /// Stream data spread over several chunks can be peeked without being consumed, and is
    /// popped across chunk boundaries.
    #[test]
    fn test_bytequeue_stream() {
        let chunk_size = 5;
        let mut bq = ByteQueue::new(chunk_size);

        let src1 = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13];
        let src2 = [51, 52, 53];
        let mut dst1 = [0; 8];
        let mut dst2 = [0; 10];

        bq.push_stream(&src1[..]).unwrap();
        // pushing an empty stream adds nothing
        bq.push_stream(&[][..]).unwrap();
        bq.push_stream(&src2[..]).unwrap();

        assert_eq!(bq.num_bytes(), src1.len() + src2.len());
        // leftover buffer space is reused and merged into, so we expect
        // ceil(total / chunk_size) chunks, with one allocation per chunk
        assert_eq!(
            bq.bytes.len(),
            (src1.len() + src2.len() - 1) / chunk_size + 1
        );
        assert_eq!(bq.total_allocations as usize, bq.bytes.len());

        // peeking copies as much as fits in the destination and removes nothing
        assert_eq!(8, bq.peek(&mut dst1[..]).unwrap().unwrap().0);
        assert_eq!(10, bq.peek(&mut dst2[..]).unwrap().unwrap().0);

        assert_eq!(dst1, [1, 2, 3, 4, 5, 6, 7, 8]);
        assert_eq!(dst2, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
        assert_eq!(bq.num_bytes(), src1.len() + src2.len());

        dst1.fill(0);
        dst2.fill(0);

        // popping removes bytes, so the second pop only gets the 8 that remain
        assert_eq!(8, bq.pop(&mut dst1[..]).unwrap().unwrap().0);
        assert_eq!(8, bq.pop(&mut dst2[..]).unwrap().unwrap().0);

        assert_eq!(dst1, [1, 2, 3, 4, 5, 6, 7, 8]);
        assert_eq!(dst2, [9, 10, 11, 12, 13, 51, 52, 53, 0, 0]);
        assert_eq!(bq.num_bytes(), 0);
    }

    /// Packets are stored one per chunk (including zero-length packets) and are removed whole
    /// even when the destination is too small to hold them.
    #[test]
    fn test_bytequeue_packet() {
        let mut bq = ByteQueue::new(5);

        let src1 = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13];
        let src2 = [51, 52, 53];
        let mut dst1 = [0; 8];
        let mut dst2 = [0; 10];

        bq.push_packet(&src1[..], src1.len()).unwrap();
        bq.push_packet(&[][..], 0).unwrap();
        bq.push_packet(&src2[..], src2.len()).unwrap();

        assert_eq!(bq.num_bytes(), src1.len() + src2.len());
        // no spare buffer is ever left over (each packet's buffer is exactly its size), so
        // every packet, even the empty one, needs its own allocation
        assert_eq!(bq.bytes.len(), 3);
        assert_eq!(bq.total_allocations, 3);

        // peeking only ever looks at the first packet, and repeated peeks see the same data
        assert_eq!(8, bq.peek(&mut dst1[..]).unwrap().unwrap().0);
        assert_eq!(10, bq.peek(&mut dst2[..]).unwrap().unwrap().0);
        assert_eq!(10, bq.peek(&mut dst2[..]).unwrap().unwrap().0);

        assert_eq!(dst1, [1, 2, 3, 4, 5, 6, 7, 8]);
        assert_eq!(dst2, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
        assert_eq!(bq.num_bytes(), src1.len() + src2.len());

        dst1.fill(0);
        dst2.fill(0);

        // the 13-byte packet is truncated to 8 bytes and the rest discarded; then the empty
        // packet yields 0 bytes, and the last packet yields its 3 bytes
        assert_eq!(8, bq.pop(&mut dst1[..]).unwrap().unwrap().0);
        assert_eq!(0, bq.pop(&mut dst2[..]).unwrap().unwrap().0);
        assert_eq!(3, bq.pop(&mut dst2[..]).unwrap().unwrap().0);

        assert_eq!(dst1, [1, 2, 3, 4, 5, 6, 7, 8]);
        assert_eq!(dst2, [51, 52, 53, 0, 0, 0, 0, 0, 0, 0]);
        assert_eq!(bq.num_bytes(), 0);
    }

    /// A packet between two stream pushes keeps them in separate chunks, and each pop stops at
    /// a chunk-type boundary.
    #[test]
    fn test_bytequeue_combined_1() {
        let mut bq = ByteQueue::new(10);

        bq.push_stream(&[1, 2, 3][..]).unwrap();
        bq.push_packet(&[4, 5, 6][..], 3).unwrap();
        bq.push_stream(&[7, 8, 9][..]).unwrap();

        assert_eq!(bq.num_bytes(), 9);
        assert_eq!(bq.bytes.len(), 3);
        // the packet and the second stream both fit in the space left over from the first
        // 10-byte buffer, so only one allocation is made
        assert_eq!(bq.total_allocations, 1);

        let mut buf = [0; 20];

        assert_eq!(
            bq.pop(&mut buf[..]).unwrap(),
            Some((3, 3, ChunkType::Stream))
        );
        assert_eq!(buf[..3], [1, 2, 3]);

        assert_eq!(
            bq.pop(&mut buf[..]).unwrap(),
            Some((3, 3, ChunkType::Packet))
        );
        assert_eq!(buf[..3], [4, 5, 6]);

        assert_eq!(
            bq.pop(&mut buf[..]).unwrap(),
            Some((3, 3, ChunkType::Stream))
        );
        assert_eq!(buf[..3], [7, 8, 9]);

        assert!(!bq.has_bytes());
    }

    /// A longer mix of stream pushes, packet pushes, `push_chunk`, and both `pop` and
    /// `pop_chunk`, with destinations smaller than the queued data.
    #[test]
    fn test_bytequeue_combined_2() {
        let mut bq = ByteQueue::new(5);

        bq.push_stream(&[1, 2, 3, 4][..]).unwrap();
        bq.push_stream(&[5][..]).unwrap();
        bq.push_stream(&[6][..]).unwrap();
        bq.push_packet(&[7, 8, 9, 10, 11, 12, 13, 14][..], 8)
            .unwrap();
        bq.push_stream(&[15, 16, 17][..]).unwrap();
        bq.push_chunk(
            Bytes::from_static(&[100, 101, 102, 103, 104, 105]),
            ChunkType::Packet,
        );
        bq.push_packet(&[][..], 0).unwrap();
        bq.push_stream(&[18][..]).unwrap();
        bq.push_stream(&[19][..]).unwrap();
        bq.push_stream(&[20, 21][..]).unwrap();

        let mut buf = [0; 20];

        // a stream pop is limited by the destination size...
        assert_eq!(
            bq.pop(&mut buf[..3]).unwrap(),
            Some((3, 3, ChunkType::Stream))
        );
        assert_eq!(buf[..3], [1, 2, 3]);

        // ...and by the next packet
        assert_eq!(
            bq.pop(&mut buf[..5]).unwrap(),
            Some((3, 3, ChunkType::Stream))
        );
        assert_eq!(buf[..3], [4, 5, 6]);

        // the packet is truncated to the destination size but removed whole
        assert_eq!(
            bq.pop(&mut buf[..4]).unwrap(),
            Some((4, 8, ChunkType::Packet))
        );
        assert_eq!(buf[..4], [7, 8, 9, 10]);

        assert_eq!(
            bq.pop(&mut buf[..4]).unwrap(),
            Some((3, 3, ChunkType::Stream))
        );
        assert_eq!(buf[..3], [15, 16, 17]);

        assert_eq!(
            bq.pop(&mut buf[..4]).unwrap(),
            Some((4, 6, ChunkType::Packet))
        );
        assert_eq!(buf[..4], [100, 101, 102, 103]);

        assert_eq!(
            bq.pop(&mut buf[..4]).unwrap(),
            Some((0, 0, ChunkType::Packet))
        );

        // `pop_chunk` only returns data from the front chunk; [18] presumably sits in a
        // different buffer from [19, 20, 21] and so was not merged with them
        assert_eq!(bq.pop_chunk(4), Some(([18][..].into(), ChunkType::Stream)));

        assert_eq!(
            bq.pop_chunk(4),
            Some(([19, 20, 21][..].into(), ChunkType::Stream))
        );

        assert_eq!(bq.pop_chunk(8), None);
        assert_eq!(bq.pop(&mut buf[..4]).unwrap(), None);
        assert!(!bq.has_bytes());
    }

    /// With a writer that always fails, a packet is lost from the queue while stream data is
    /// kept.
    #[test]
    fn test_bytequeue_fallible_writer() {
        struct TestWriter;

        impl std::io::Write for TestWriter {
            fn write(&mut self, _buf: &[u8]) -> std::io::Result<usize> {
                Err(std::io::ErrorKind::BrokenPipe.into())
            }
            fn flush(&mut self) -> std::io::Result<()> {
                Ok(())
            }
        }

        let mut bq = ByteQueue::new(10);

        bq.push_packet(&[4, 5, 6][..], 3).unwrap();
        bq.push_stream(&[1, 2, 3][..]).unwrap();

        let mut writer = TestWriter {};

        // the packet is removed from the queue before the write fails
        bq.pop(&mut writer).unwrap_err();
        // the stream data is not removed when the write fails
        bq.pop(&mut writer).unwrap_err();

        assert_eq!(bq.num_bytes(), 3);
    }

    /// Randomized check that `peek` returns the same result and bytes as the `pop` that
    /// follows it.
    #[test]
    fn test_bytequeue_peek() {
        let mut rng = ChaCha20Rng::seed_from_u64(1234);

        const PROB_PUSH: f64 = 0.8;
        const PROB_POP: f64 = 0.9;
        const PROB_STREAM: f64 = 0.5;
        const MAX_PUSH: usize = 20;
        const MAX_POP: usize = 30;

        // far fewer iterations under miri, which is much slower
        #[cfg(not(miri))]
        const NUM_ITER: usize = 5000;
        #[cfg(miri)]
        const NUM_ITER: usize = 10;

        // pop more often and in larger amounts than we push, so the queue stays bounded
        static_assertions::const_assert!(PROB_POP > PROB_PUSH);
        static_assertions::const_assert!(MAX_POP > MAX_PUSH);

        let mut bq = ByteQueue::new(10);

        for _ in 0..NUM_ITER {
            if rng.random_bool(PROB_PUSH) {
                let mut bytes = vec![0u8; rng.random_range(0..MAX_PUSH)];
                rng.fill_bytes(&mut bytes);

                if rng.random_bool(PROB_STREAM) {
                    bq.push_stream(&bytes[..]).unwrap();
                } else {
                    bq.push_packet(&bytes[..], bytes.len()).unwrap();
                }
            }

            let pop_size = rng.random_range(0..MAX_POP);

            // always peek, even if we don't end up popping
            let mut peeked_bytes = vec![0u8; pop_size];
            let peek_rv = bq.peek(&mut peeked_bytes[..]).unwrap();

            if rng.random_bool(PROB_POP) {
                // pop into a buffer of the same size; it must match what was peeked
                let mut popped_bytes = vec![0u8; pop_size];
                let pop_rv = bq.pop(&mut popped_bytes[..]).unwrap();

                assert_eq!(peek_rv, pop_rv);
                assert_eq!(popped_bytes, peeked_bytes);
            }
        }
    }
}