1use std::io::Write;
2use std::mem::MaybeUninit;
3use std::net::{IpAddr, SocketAddrV4};
4use std::sync::Arc;
5
6use crate::host::network::interface::FifoPacketPriority;
7use crate::utility::ObjectCounter;
8use crate::utility::pcap_writer::PacketDisplay;
9
10use atomic_refcell::AtomicRefCell;
11use bytes::Bytes;
12use shadow_shim_helper_rs::HostId;
13
/// Milestones in a packet's life that can be recorded with [`Packet::add_status`].
///
/// Recording is a no-op unless trace-level logging is enabled. The variant prefixes appear to
/// name the stage that records each event (`Snd*` sending side, `Inet*` simulated internet,
/// `Router*` router queue, `Rcv*` receiving side, `Relay*` relays) — NOTE(review): inferred
/// from the names only; confirm against the call sites.
#[derive(Copy, Clone, Debug)]
pub enum PacketStatus {
    SndCreated,
    SndTcpEnqueueThrottled,
    SndTcpEnqueueRetransmit,
    SndTcpDequeueRetransmit,
    SndTcpRetransmitted,
    SndSocketBuffered,
    SndInterfaceSent,
    InetSent,
    InetDropped,
    RouterEnqueued,
    RouterDequeued,
    RouterDropped,
    RcvInterfaceReceived,
    RcvInterfaceDropped,
    RcvSocketProcessed,
    RcvSocketDropped,
    RcvTcpEnqueueUnordered,
    RcvSocketBuffered,
    RcvSocketDelivered,
    Destroyed,
    RelayCached,
    RelayForwarded,
}
40
/// Transport protocols a [`Packet`] can carry, identified on the wire by their IANA protocol
/// number (see [`IanaProtocol::number`]).
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)]
pub enum IanaProtocol {
    /// Transmission Control Protocol (protocol number 6).
    Tcp,
    /// User Datagram Protocol (protocol number 17).
    Udp,
}
47
48impl IanaProtocol {
49    pub fn number(&self) -> u8 {
52        match self {
56            IanaProtocol::Tcp => 6,
57            IanaProtocol::Udp => 17,
58        }
59    }
60}
61
/// IPv4 type-of-service classes.
///
/// In the code visible here only [`TypeOfService::Normal`] is ever assigned (by `Header::new`).
/// The other variants presumably mirror the RFC 1349 TOS values — NOTE(review): confirm.
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum TypeOfService {
    Normal,
    MinimizeDelay,
    MaximizeReliability,
    MaximizeThroughput,
}
77
/// A shared, reference-counted handle to a [`Packet`].
///
/// Cloning only bumps the `Arc` strong count. Equality compares pointer identity, not packet
/// contents. Use `new_copy_inner` for an independent copy. The handle can be converted to and
/// from a raw pointer for the C API (`into_raw`/`from_raw`).
#[derive(Clone, Debug)]
pub struct PacketRc {
    // The shared packet; all `PacketRc` clones point at this one allocation.
    inner: Arc<Packet>,
}
99
100impl PacketRc {
101    pub fn new_ipv4_tcp(
107        header: tcp::TcpHeader,
108        payload: tcp::Payload,
109        priority: FifoPacketPriority,
110    ) -> Self {
111        Self::from(Packet::new_ipv4_tcp(header, payload, priority))
112    }
113
114    pub fn new_ipv4_udp(
120        src: SocketAddrV4,
121        dst: SocketAddrV4,
122        payload: Bytes,
123        priority: FifoPacketPriority,
124    ) -> Self {
125        Self::from(Packet::new_ipv4_udp(src, dst, payload, priority))
126    }
127
128    #[cfg(test)]
134    pub fn new_ipv4_udp_mock() -> Self {
135        Self::from(Packet::new_ipv4_udp_mock())
136    }
137
138    pub fn new_copy_inner(&self) -> Self {
144        Self::from(self.inner.as_ref().clone())
147    }
148
149    pub fn from_raw(packet_ptr: *mut Packet) -> Self {
164        assert!(!packet_ptr.is_null());
165        Self {
166            inner: unsafe { Arc::from_raw(packet_ptr) },
167        }
168    }
169
170    pub fn into_raw(self) -> *mut Packet {
185        Arc::into_raw(self.inner).cast_mut()
186    }
187
188    fn borrow_raw(packet_ptr: *const Packet) -> Self {
202        assert!(!packet_ptr.is_null());
203        unsafe { Arc::increment_strong_count(packet_ptr) };
204        PacketRc::from_raw(packet_ptr.cast_mut())
205    }
206
207    fn borrow_raw_mut(packet_ptr: *mut Packet) -> Self {
218        Self::borrow_raw(packet_ptr.cast_const())
219    }
220}
221
222impl PartialEq for PacketRc {
223    fn eq(&self, other: &Self) -> bool {
225        Arc::ptr_eq(&self.inner, &other.inner)
227    }
228}
229
230impl Eq for PacketRc {}
231
232impl From<Packet> for PacketRc {
233    fn from(packet: Packet) -> Self {
234        Self {
235            inner: Arc::new(packet),
236        }
237    }
238}
239
240impl std::ops::Deref for PacketRc {
242    type Target = Packet;
243    fn deref(&self) -> &Self::Target {
244        self.inner.as_ref()
245    }
246}
247
/// A simulated network packet: network-layer header, transport data, and bookkeeping
/// metadata.
///
/// Usually handled through [`PacketRc`]. Most state is fixed after construction. The
/// exceptions are the `LegacyTcp` transport data, which sits behind an `AtomicRefCell` and is
/// mutated by the C API in `export`, and the trace-only status list in `meta`.
#[derive(Clone, Debug)]
pub struct Packet {
    // IP-level source and destination addresses.
    header: Header,
    // Transport-layer (TCP or UDP) header and payload.
    data: Data,
    // Queueing priority plus debugging/identification metadata.
    meta: Metadata,
    // Registers this instance with the "Packet" object counter (see `Packet::new`); held only
    // for that side effect.
    _counter: ObjectCounter,
}
271
272impl Packet {
273    fn new(header: Header, data: Data, meta: Metadata) -> Self {
277        Self {
278            header,
279            data,
280            meta,
281            _counter: ObjectCounter::new("Packet"),
282        }
283    }
284
285    pub fn new_ipv4_tcp(
287        header: tcp::TcpHeader,
288        payload: tcp::Payload,
289        priority: FifoPacketPriority,
290    ) -> Self {
291        let hdr = header;
292        let header = Header::new(IpAddr::V4(hdr.ip.src), IpAddr::V4(hdr.ip.dst));
293
294        let tcp_packet = TcpData::new(TcpHeader::from(hdr), payload.0);
295        let data = Data::from(tcp_packet);
296
297        let meta = Metadata::new(priority);
298
299        Self::new(header, data, meta)
300    }
301
302    pub fn new_ipv4_udp(
304        src: SocketAddrV4,
305        dst: SocketAddrV4,
306        payload: Bytes,
307        priority: FifoPacketPriority,
308    ) -> Self {
309        let header = Header::new(IpAddr::V4(*src.ip()), IpAddr::V4(*dst.ip()));
310
311        let udp_header = UdpHeader::new(src.port(), dst.port());
312        let udp_packet = UdpData::new(udp_header, payload);
313        let data = Data::from(udp_packet);
314
315        let meta = Metadata::new(priority);
316
317        Self::new(header, data, meta)
318    }
319
320    #[cfg(test)]
323    pub fn new_ipv4_udp_mock() -> Self {
324        let unspec = SocketAddrV4::new(std::net::Ipv4Addr::UNSPECIFIED, 0);
325        Self::new_ipv4_udp(unspec, unspec, Bytes::copy_from_slice(&[0; 1000]), 0)
327    }
328
329    pub fn ipv4_tcp_header(&self) -> Option<tcp::TcpHeader> {
336        let hdr = &self.header;
337
338        let IpAddr::V4(src) = hdr.src else {
339            return None;
340        };
341        let IpAddr::V4(dst) = hdr.dst else {
342            return None;
343        };
344
345        let tcp_hdr = match &self.data {
346            Data::LegacyTcp(_) => unimplemented!(),
348            Data::Tcp(tcp) => tcp.header.clone(),
349            Data::Udp(_) => return None,
350        };
351
352        Some(tcp::TcpHeader {
353            ip: tcp::Ipv4Header { src, dst },
354            flags: tcp_hdr.flags,
355            src_port: tcp_hdr.src_port,
356            dst_port: tcp_hdr.dst_port,
357            seq: tcp_hdr.sequence,
358            ack: tcp_hdr.acknowledgement,
359            window_size: tcp_hdr.window_size,
360            selective_acks: tcp_hdr.selective_acks.map(|x| x.into()),
361            window_scale: tcp_hdr.window_scale,
362            timestamp: tcp_hdr.timestamp,
363            timestamp_echo: tcp_hdr.timestamp_echo,
364        })
365    }
366
367    pub fn payload(&self) -> Vec<Bytes> {
378        match &self.data {
379            Data::LegacyTcp(tcp_rc) => tcp_rc.borrow().payload.clone(),
380            Data::Tcp(tcp) => tcp.payload.clone(),
381            Data::Udp(udp) => vec![udp.payload.clone()],
382        }
383    }
384
385    #[allow(clippy::len_without_is_empty)]
388    pub fn len(&self) -> usize {
389        self.header.len().checked_add(self.data.len()).unwrap()
390    }
391
392    pub fn payload_len(&self) -> usize {
395        self.data.payload_len()
396    }
397
398    pub fn add_status(&self, status: PacketStatus) {
402        if log::log_enabled!(log::Level::Trace) {
403            if let Some(vec) = self.meta.statuses.as_ref() {
404                vec.borrow_mut().push(status);
405            }
406            log::trace!("[{status:?}] {self:?}");
407        }
408    }
409
410    pub fn src_ipv4_address(&self) -> SocketAddrV4 {
416        let IpAddr::V4(addr) = self.header.src else {
417            unimplemented!()
418        };
419
420        let port = match &self.data {
421            Data::LegacyTcp(tcp_rc) => tcp_rc.borrow().header.src_port,
422            Data::Tcp(tcp) => tcp.header.src_port,
423            Data::Udp(udp) => udp.header.src_port,
424        };
425
426        SocketAddrV4::new(addr, port)
427    }
428
429    pub fn dst_ipv4_address(&self) -> SocketAddrV4 {
435        let IpAddr::V4(addr) = self.header.dst else {
436            unimplemented!()
437        };
438
439        let port = match &self.data {
440            Data::LegacyTcp(tcp_rc) => tcp_rc.borrow().header.dst_port,
441            Data::Tcp(tcp) => tcp.header.dst_port,
442            Data::Udp(udp) => udp.header.dst_port,
443        };
444
445        SocketAddrV4::new(addr, port)
446    }
447
448    pub fn priority(&self) -> FifoPacketPriority {
450        self.meta.priority
451    }
452
453    pub fn iana_protocol(&self) -> IanaProtocol {
455        self.data.iana_protocol()
456    }
457}
458
/// Network-layer (IP) header of a [`Packet`].
#[derive(Clone, Debug)]
struct Header {
    src: IpAddr,
    dst: IpAddr,
    // Always `TypeOfService::Normal` (see `Header::new`); not otherwise read in the code
    // visible here.
    _tos: TypeOfService,
}
466
467impl Header {
468    pub fn new(src: IpAddr, dst: IpAddr) -> Self {
469        Self {
471            src,
472            dst,
473            _tos: TypeOfService::Normal,
474        }
475    }
476
477    pub fn len(&self) -> usize {
478        match &self.dst {
479            IpAddr::V4(_) => 20usize,
481            IpAddr::V6(_) => 40usize,
483        }
484    }
485}
486
/// Transport-layer contents of a [`Packet`].
#[derive(Clone, Debug)]
enum Data {
    // TCP data of packets created through the C API (`packet_new_tcp`). The `AtomicRefCell`
    // lets the C API update the header and append payload after creation.
    LegacyTcp(AtomicRefCell<TcpData>),
    // TCP data of packets built from Rust; immutable after construction.
    Tcp(TcpData),
    Udp(UdpData),
}
497
498impl Data {
499    pub fn len(&self) -> usize {
500        match self {
501            Data::LegacyTcp(tcp_ref) => tcp_ref.borrow().len(),
502            Data::Tcp(tcp) => tcp.len(),
503            Data::Udp(udp) => udp.len(),
504        }
505    }
506
507    pub fn payload_len(&self) -> usize {
508        match self {
509            Data::LegacyTcp(tcp_ref) => tcp_ref.borrow().payload_len(),
510            Data::Tcp(tcp) => tcp.payload_len(),
511            Data::Udp(udp) => udp.payload_len(),
512        }
513    }
514
515    pub fn iana_protocol(&self) -> IanaProtocol {
516        match self {
517            Data::LegacyTcp(tcp_ref) => tcp_ref.borrow().iana_protocol(),
518            Data::Tcp(tcp) => tcp.iana_protocol(),
519            Data::Udp(udp) => udp.iana_protocol(),
520        }
521    }
522}
523
524impl From<UdpData> for Data {
525    fn from(packet: UdpData) -> Self {
526        Self::Udp(packet)
527    }
528}
529
530impl From<TcpData> for Data {
531    fn from(packet: TcpData) -> Self {
532        Self::Tcp(packet)
533    }
534}
535
/// A TCP segment: header plus payload.
#[derive(Clone, Debug)]
struct TcpData {
    header: TcpHeader,
    // Payload stored as a list of chunks, kept as given or appended (never concatenated).
    payload: Vec<Bytes>,
}
547
548impl TcpData {
549    pub fn new(header: TcpHeader, payload: Vec<Bytes>) -> Self {
550        Self { header, payload }
551    }
552
553    pub fn len(&self) -> usize {
554        self.header.len().checked_add(self.payload_len()).unwrap()
555    }
556
557    pub fn payload_len(&self) -> usize {
558        self.payload
559            .iter()
560            .fold(0usize, |acc, x| acc.checked_add(x.len()).unwrap())
562    }
563
564    pub fn iana_protocol(&self) -> IanaProtocol {
565        IanaProtocol::Tcp
566    }
567}
568
/// TCP header fields without the IP addresses (those live in the packet's [`Header`]).
///
/// Ports are stored in host byte order; the C API converts from and to network order.
#[derive(Clone, Debug, PartialEq)]
struct TcpHeader {
    src_port: u16,
    dst_port: u16,
    flags: tcp::TcpFlags,
    sequence: u32,
    acknowledgement: u32,
    window_size: u16,
    // Up to four SACK ranges.
    selective_acks: Option<TcpSelectiveAcks>,
    // When `Some`, the window-scale option is counted by `TcpHeader::len` and serialized.
    window_scale: Option<u8>,
    // Carried for the TCP stacks, but not counted in `TcpHeader::len` and not serialized.
    timestamp: Option<u32>,
    timestamp_echo: Option<u32>,
}
584
585impl TcpHeader {
586    #[allow(dead_code)]
588    pub fn new(
589        src_port: u16,
590        dst_port: u16,
591        flags: tcp::TcpFlags,
592        sequence: u32,
593        acknowledgement: u32,
594        window_size: u16,
595        selective_acks: Option<TcpSelectiveAcks>,
596        window_scale: Option<u8>,
597        timestamp: Option<u32>,
598        timestamp_echo: Option<u32>,
599    ) -> Self {
600        Self {
601            src_port,
602            dst_port,
603            sequence,
604            flags,
605            acknowledgement,
606            window_size,
607            selective_acks,
608            window_scale,
609            timestamp,
610            timestamp_echo,
611        }
612    }
613
614    pub fn len(&self) -> usize {
618        let mut len = 20usize;
620
621        if self.window_scale.is_some() {
623            len += 3;
625        }
626
627        if !len.is_multiple_of(4) {
631            len += 4 - (len % 4);
632        }
633
634        len
635    }
636}
637
638impl From<tcp::TcpHeader> for TcpHeader {
639    fn from(hdr: tcp::TcpHeader) -> Self {
640        TcpHeader {
641            src_port: hdr.src_port,
642            dst_port: hdr.dst_port,
643            flags: hdr.flags,
644            sequence: hdr.seq,
645            acknowledgement: hdr.ack,
646            window_size: hdr.window_size,
647            selective_acks: hdr.selective_acks.map(|x| x.into()),
648            window_scale: hdr.window_scale,
649            timestamp: hdr.timestamp,
650            timestamp_echo: hdr.timestamp_echo,
651        }
652    }
653}
654
/// A fixed-capacity list of up to four TCP SACK ranges.
#[derive(Clone, Copy, Debug, Default)]
struct TcpSelectiveAcks {
    // Number of valid entries at the front of `ranges` (at most 4).
    len: u8,
    // SACK ranges as `(u32, u32)` pairs; entries at index `len` and beyond are unused.
    ranges: [(u32, u32); 4],
}
661
662impl From<tcp::util::SmallArrayBackedSlice<4, (u32, u32)>> for TcpSelectiveAcks {
663    fn from(array: tcp::util::SmallArrayBackedSlice<4, (u32, u32)>) -> Self {
664        let mut selective_acks = Self::default();
665
666        for (i, sack) in array.as_ref().iter().enumerate() {
667            selective_acks.ranges[i] = (sack.0, sack.1);
668            selective_acks.len += 1;
669            if selective_acks.len >= 4 {
670                break;
671            }
672        }
673
674        selective_acks
675    }
676}
677
678impl From<TcpSelectiveAcks> for tcp::util::SmallArrayBackedSlice<4, (u32, u32)> {
679    fn from(selective_acks: TcpSelectiveAcks) -> Self {
680        assert!(selective_acks.len <= 4);
681        Self::new(&selective_acks.ranges[0..(selective_acks.len as usize)]).unwrap()
682    }
683}
684
685impl PartialEq for TcpSelectiveAcks {
686    fn eq(&self, other: &Self) -> bool {
687        if self.len != other.len {
688            return false;
689        }
690        for i in 0..self.len as usize {
691            if self.ranges[i] != other.ranges[i] {
692                return false;
693            }
694        }
695        true
696    }
697}
698
/// A UDP datagram: header plus a single contiguous payload.
#[derive(Clone, Debug)]
struct UdpData {
    header: UdpHeader,
    payload: Bytes,
}
706
707impl UdpData {
708    pub fn new(header: UdpHeader, payload: Bytes) -> Self {
709        Self { header, payload }
710    }
711
712    pub fn len(&self) -> usize {
713        self.header.len().checked_add(self.payload_len()).unwrap()
714    }
715
716    pub fn payload_len(&self) -> usize {
717        self.payload.len()
718    }
719
720    pub fn iana_protocol(&self) -> IanaProtocol {
721        IanaProtocol::Udp
722    }
723}
724
/// UDP header fields kept by the simulation; ports are in host byte order.
///
/// The length is derived from the payload, and the checksum is written as zero when the packet
/// is serialized (see `write_udpdata_bytes`).
#[derive(Clone, Debug, PartialEq)]
struct UdpHeader {
    src_port: u16,
    dst_port: u16,
}
731
732impl UdpHeader {
733    pub fn new(src_port: u16, dst_port: u16) -> Self {
734        Self { src_port, dst_port }
735    }
736
737    pub fn len(&self) -> usize {
738        8usize
740    }
741}
742
/// Non-wire bookkeeping attached to a [`Packet`].
#[derive(Clone, Debug)]
struct Metadata {
    priority: FifoPacketPriority,
    // History of recorded `PacketStatus` events. `Some` only if trace logging was enabled when
    // the metadata was created.
    statuses: Option<AtomicRefCell<Vec<PacketStatus>>>,
    // Set only for packets created through the C API (`Metadata::new_legacy`); not read in the
    // code visible here.
    _host_id: Option<HostId>,
    // Set only for packets created through the C API (`Metadata::new_legacy`); not read in the
    // code visible here.
    _packet_id: Option<u64>,
}
768
769impl Metadata {
770    pub fn new(priority: FifoPacketPriority) -> Self {
771        Self {
772            priority,
773            _host_id: None,
774            _packet_id: None,
775            statuses: log::log_enabled!(log::Level::Trace).then(AtomicRefCell::default),
778        }
779    }
780
781    fn new_legacy(priority: FifoPacketPriority, host_id: HostId, packet_id: u64) -> Self {
788        Self {
789            priority,
790            _host_id: Some(host_id),
791            _packet_id: Some(packet_id),
792            statuses: log::log_enabled!(log::Level::Trace).then(AtomicRefCell::default),
795        }
796    }
797}
798
799impl PacketDisplay for Packet {
800    fn display_bytes(&self, mut writer: impl Write) -> std::io::Result<()> {
801        let version_and_header_length: u8 = 0x45;
804        let fields: u8 = 0x0;
805        let total_length: u16 = self.len().try_into().unwrap();
806        let identification: u16 = 0x0;
807        let flags_and_fragment: u16 = 0x4000;
808        let time_to_live: u8 = 64;
809        let iana_protocol: u8 = self.data.iana_protocol().number();
810        let header_checksum: u16 = 0x0;
811        let source_ip: [u8; 4] = self.src_ipv4_address().ip().to_bits().to_be_bytes();
812        let dest_ip: [u8; 4] = self.dst_ipv4_address().ip().to_bits().to_be_bytes();
813
814        writer.write_all(&[version_and_header_length, fields])?;
817        writer.write_all(&total_length.to_be_bytes())?;
819        writer.write_all(&identification.to_be_bytes())?;
821        writer.write_all(&flags_and_fragment.to_be_bytes())?;
823        writer.write_all(&[time_to_live, iana_protocol])?;
826        writer.write_all(&header_checksum.to_be_bytes())?;
828        writer.write_all(&source_ip)?;
830        writer.write_all(&dest_ip)?;
832
833        match &self.data {
836            Data::LegacyTcp(tcp_ref) => write_tcpdata_bytes(&tcp_ref.borrow(), writer),
837            Data::Tcp(tcp) => write_tcpdata_bytes(tcp, writer),
838            Data::Udp(udp) => write_udpdata_bytes(udp, writer),
839        }?;
840
841        Ok(())
842    }
843}
844
845fn write_tcpdata_bytes(data: &TcpData, mut writer: impl Write) -> std::io::Result<()> {
846    let tcp_hdr = &data.header;
849
850    let mut options = [0u8; 40];
852    let mut options_len = 0;
853
854    if let Some(window_scale) = tcp_hdr.window_scale {
855        options[options_len..][..3].copy_from_slice(&[3, 3, window_scale]);
857        options_len += 3;
858    }
859
860    if options_len % 4 != 0 {
863        let padding = 4 - (options_len % 4);
865        options_len += padding;
866    }
867
868    let options = &options[..options_len];
869
870    let mut tcp_flags = tcp_hdr.flags;
873
874    let header_len: usize = 20usize.checked_add(options.len()).unwrap();
876    assert_eq!(header_len, tcp_hdr.len());
877
878    let mut header_len = u8::try_from(header_len).unwrap();
881    header_len /= 4;
882    header_len <<= 4;
883
884    tcp_flags.remove(tcp::TcpFlags::ECE);
887    tcp_flags.remove(tcp::TcpFlags::CWR);
888
889    writer.write_all(&tcp_hdr.src_port.to_be_bytes())?;
893    writer.write_all(&tcp_hdr.dst_port.to_be_bytes())?;
895    writer.write_all(&tcp_hdr.sequence.to_be_bytes())?;
897    writer.write_all(&tcp_hdr.acknowledgement.to_be_bytes())?;
899    writer.write_all(&[header_len, tcp_flags.bits()])?;
902    writer.write_all(&tcp_hdr.window_size.to_be_bytes())?;
904    let checksum: u16 = 0u16;
906    writer.write_all(&checksum.to_be_bytes())?;
907    let urgent_pointer: u16 = 0u16;
909    writer.write_all(&urgent_pointer.to_be_bytes())?;
910
911    writer.write_all(options)?;
912
913    for bytes in &data.payload {
916        writer.write_all(bytes)?;
917    }
918
919    Ok(())
920}
921
922fn write_udpdata_bytes(data: &UdpData, mut writer: impl Write) -> std::io::Result<()> {
923    writer.write_all(&data.header.src_port.to_be_bytes())?;
927    writer.write_all(&data.header.dst_port.to_be_bytes())?;
929    let udp_len: u16 = u16::try_from(data.len()).unwrap();
931    writer.write_all(&udp_len.to_be_bytes())?;
932    let checksum: u16 = 0x0;
934    writer.write_all(&checksum.to_be_bytes())?;
935
936    writer.write_all(&data.payload)?;
939
940    Ok(())
941}
942
#[cfg(test)]
mod tests {
    use std::net::Ipv4Addr;

    use super::*;

    /// Source endpoint shared by every test.
    fn src_endpoint() -> SocketAddrV4 {
        SocketAddrV4::new(Ipv4Addr::new(192, 168, 1, 1), 10_000)
    }

    /// Destination endpoint shared by every test.
    fn dst_endpoint() -> SocketAddrV4 {
        SocketAddrV4::new(Ipv4Addr::new(192, 168, 1, 2), 80)
    }

    #[test]
    fn ipv4_udp() {
        let (src, dst, priority) = (src_endpoint(), dst_endpoint(), 123);
        let payload = Bytes::from_static(b"Hello World!");

        let packetrc = PacketRc::new_ipv4_udp(src, dst, payload.clone(), priority);

        // Addressing and metadata round-trip unchanged.
        assert_eq!(src, packetrc.src_ipv4_address());
        assert_eq!(dst, packetrc.dst_ipv4_address());
        assert_eq!(priority, packetrc.priority());
        assert_eq!(IanaProtocol::Udp, packetrc.iana_protocol());

        // A UDP payload is exposed as exactly one chunk.
        assert_eq!(payload.len(), packetrc.payload_len());
        let chunks = packetrc.payload();
        assert_eq!(1, chunks.len());
        assert_eq!(payload, chunks.first().unwrap());
    }

    #[test]
    fn ipv4_udp_empty() {
        let payload = Bytes::new();
        let packetrc =
            PacketRc::new_ipv4_udp(src_endpoint(), dst_endpoint(), payload.clone(), 123);

        assert_eq!(0, packetrc.payload_len());
        assert_eq!(payload.len(), packetrc.payload_len());

        // Even an empty UDP payload is reported as a single (empty) chunk.
        let chunks = packetrc.payload();
        assert_eq!(1, chunks.len());
        assert_eq!(0, chunks.first().unwrap().len());
    }

    /// Builds a SYN header from `src` to `dst` with every optional TCP field populated.
    fn make_tcp_header(src: SocketAddrV4, dst: SocketAddrV4) -> tcp::TcpHeader {
        let selective_acks = Some(
            tcp::util::SmallArrayBackedSlice::<4, (u32, u32)>::new(&[(1, 3), (5, 6)]).unwrap(),
        );

        tcp::TcpHeader {
            ip: tcp::Ipv4Header {
                src: *src.ip(),
                dst: *dst.ip(),
            },
            flags: tcp::TcpFlags::SYN,
            src_port: src.port(),
            dst_port: dst.port(),
            seq: 10,
            ack: 3,
            window_size: 25,
            selective_acks,
            window_scale: Some(2),
            timestamp: Some(123456),
            timestamp_echo: Some(123450),
        }
    }

    #[test]
    fn ipv4_tcp() {
        let (src, dst, priority) = (src_endpoint(), dst_endpoint(), 123);
        let tcp_hdr = make_tcp_header(src, dst);
        let payload = tcp::Payload(vec![
            Bytes::from_static(b"Hello"),
            Bytes::from_static(b" World!"),
        ]);

        let packetrc = PacketRc::new_ipv4_tcp(tcp_hdr, payload.clone(), priority);

        assert_eq!(src, packetrc.src_ipv4_address());
        assert_eq!(dst, packetrc.dst_ipv4_address());
        assert_eq!(priority, packetrc.priority());
        assert_eq!(IanaProtocol::Tcp, packetrc.iana_protocol());
        // The header survives the round trip through the packet's internal representation.
        assert_eq!(
            TcpHeader::from(tcp_hdr),
            TcpHeader::from(packetrc.ipv4_tcp_header().unwrap())
        );

        // Chunk boundaries are preserved, not concatenated.
        assert_eq!(payload.len() as usize, packetrc.payload_len());
        let chunks = packetrc.payload();
        assert_eq!(2, chunks.len());
        for (expected, actual) in payload.0.iter().zip(&chunks) {
            assert_eq!(expected, actual);
        }
    }

    #[test]
    fn ipv4_tcp_empty() {
        let tcp_hdr = make_tcp_header(src_endpoint(), dst_endpoint());
        let priority = 123;

        // No chunks at all.
        let packetrc = PacketRc::new_ipv4_tcp(tcp_hdr, tcp::Payload(vec![]), priority);
        assert_eq!(0, packetrc.payload_len());
        assert_eq!(0, packetrc.payload().len());

        // Two empty chunks: still zero bytes, but both chunks are kept.
        let empty_chunks = tcp::Payload(vec![Bytes::new(), Bytes::new()]);
        let packetrc = PacketRc::new_ipv4_tcp(tcp_hdr, empty_chunks, priority);
        assert_eq!(0, packetrc.payload_len());
        let chunks = packetrc.payload();
        assert_eq!(2, chunks.len());
        assert_eq!(0, chunks.first().unwrap().len());
        assert_eq!(0, chunks.last().unwrap().len());
    }
}
1069
1070mod export {
1077    use std::cmp::Ordering;
1078    use std::io::Write;
1079
1080    use shadow_shim_helper_rs::simulation_time::SimulationTime;
1081    use shadow_shim_helper_rs::syscall_types::UntypedForeignPtr;
1082
1083    use crate::cshadow as c;
1084    use crate::host::memory_manager::MemoryManager;
1085    use crate::host::syscall::types::ForeignArrayPtr;
1086
1087    use super::*;
1088
1089    #[unsafe(no_mangle)]
1090    pub extern "C-unwind" fn packet_new_tcp(
1091        host_id: HostId,
1092        packet_id: u64,
1093        flags: c::ProtocolTCPFlags,
1094        src_ip: libc::in_addr_t,
1095        src_port: libc::in_port_t,
1096        dst_ip: libc::in_addr_t,
1097        dst_port: libc::in_port_t,
1098        seq: u32,
1099        priority: u64,
1100    ) -> *mut Packet {
1101        let header = Header::new(
1103            IpAddr::V4(u32::from_be(src_ip).into()),
1104            IpAddr::V4(u32::from_be(dst_ip).into()),
1105        );
1106
1107        let data = Data::LegacyTcp(AtomicRefCell::new(TcpData {
1109            header: TcpHeader {
1110                src_port: u16::from_be(src_port),
1111                dst_port: u16::from_be(dst_port),
1112                flags: legacy_flags_to_tcp_flags(flags),
1113                sequence: seq,
1114                acknowledgement: 0,
1115                window_size: 0,
1116                selective_acks: None,
1117                window_scale: None,
1118                timestamp: None,
1119                timestamp_echo: None,
1120            },
1121            payload: vec![],
1122        }));
1123
1124        let meta = Metadata::new_legacy(priority, host_id, packet_id);
1125        let packet = Packet::new(header, data, meta);
1126
1127        PacketRc::from(packet).into_raw()
1129    }
1130
1131    #[unsafe(no_mangle)]
1132    pub extern "C-unwind" fn packet_ref(packet_ptr: *mut Packet) {
1133        assert!(!packet_ptr.is_null());
1134        unsafe { Arc::increment_strong_count(packet_ptr) };
1135    }
1136
1137    #[unsafe(no_mangle)]
1138    pub extern "C-unwind" fn packet_unref(packet_ptr: *mut Packet) {
1139        assert!(!packet_ptr.is_null());
1140        unsafe { Arc::decrement_strong_count(packet_ptr) };
1141    }
1142
1143    #[unsafe(no_mangle)]
1144    pub extern "C-unwind" fn packet_updateTCP(
1145        packet_ptr: *mut Packet,
1146        ack: libc::c_uint,
1147        sel_acks: c::PacketSelectiveAcks,
1148        window_size: libc::c_uint,
1149        window_scale: libc::c_uchar,
1150        window_scale_set: bool,
1151        ts_val: c::CSimulationTime,
1152        ts_echo: c::CSimulationTime,
1153    ) {
1154        let packet = PacketRc::borrow_raw_mut(packet_ptr);
1155
1156        let Data::LegacyTcp(tcp) = &packet.data else {
1157            unimplemented!()
1158        };
1159
1160        let mut tcp = tcp.borrow_mut();
1161
1162        tcp.header.acknowledgement = ack;
1163        tcp.header.selective_acks = Some(TcpSelectiveAcks::from(sel_acks));
1164
1165        tcp.header.window_size = u16::try_from(window_size).unwrap_or(u16::MAX);
1166        if window_scale_set {
1167            tcp.header.window_scale = Some(window_scale);
1168        } else {
1169            tcp.header.window_scale = None;
1170        }
1171
1172        tcp.header.timestamp = from_legacy_timestamp(ts_val);
1176        tcp.header.timestamp_echo = from_legacy_timestamp(ts_echo);
1177    }
1178
1179    #[unsafe(no_mangle)]
1180    pub extern "C-unwind" fn packet_getTCPHeader(packet_ptr: *const Packet) -> c::PacketTCPHeader {
1181        let packet = PacketRc::borrow_raw(packet_ptr);
1182
1183        let IpAddr::V4(src_ip) = packet.header.src else {
1184            unimplemented!()
1185        };
1186        let IpAddr::V4(dst_ip) = packet.header.dst else {
1187            unimplemented!()
1188        };
1189        let Data::LegacyTcp(tcp_rc) = &packet.data else {
1190            unimplemented!()
1191        };
1192        let tcp = tcp_rc.borrow();
1193
1194        let mut c_hdr: c::PacketTCPHeader = unsafe { MaybeUninit::zeroed().assume_init() };
1195
1196        c_hdr.flags = tcp_flags_to_legacy_flags(tcp.header.flags);
1197        c_hdr.sourceIP = u32::from(src_ip).to_be();
1198        c_hdr.sourcePort = tcp.header.src_port.to_be();
1199        c_hdr.destinationIP = u32::from(dst_ip).to_be();
1200        c_hdr.destinationPort = tcp.header.dst_port.to_be();
1201        c_hdr.sequence = tcp.header.sequence;
1202        c_hdr.acknowledgment = tcp.header.acknowledgement;
1203        c_hdr.selectiveACKs = to_legacy_sel_acks(tcp.header.selective_acks);
1204        c_hdr.window = u32::from(tcp.header.window_size);
1205        if let Some(scale) = tcp.header.window_scale {
1206            c_hdr.windowScale = scale;
1207            c_hdr.windowScaleSet = true;
1208        }
1209        c_hdr.timestampValue = to_legacy_timestamp(tcp.header.timestamp);
1210        c_hdr.timestampEcho = to_legacy_timestamp(tcp.header.timestamp_echo);
1211
1212        c_hdr
1213    }
1214
1215    #[unsafe(no_mangle)]
1216    pub extern "C-unwind" fn packet_appendPayloadWithMemoryManager(
1217        packet_ptr: *mut Packet,
1218        src: UntypedForeignPtr,
1219        src_len: u64,
1220        mem: *const MemoryManager,
1221    ) {
1222        let packet = PacketRc::borrow_raw_mut(packet_ptr);
1224        let mem = unsafe { mem.as_ref() }.unwrap();
1225
1226        let Data::LegacyTcp(tcp) = &packet.data else {
1227            unimplemented!()
1228        };
1229
1230        let len = usize::try_from(src_len).unwrap();
1231        let src = ForeignArrayPtr::new(src.cast::<MaybeUninit<u8>>(), len);
1232
1233        let mut dst = Box::<[u8]>::new_uninit_slice(len);
1237
1238        log::trace!(
1239            "Requested to read payload of len {len} from the managed process into the packet's \
1240            payload buffer",
1241        );
1242
1243        if let Err(e) = mem.copy_from_ptr(&mut dst[..], src) {
1247            panic!(
1249                "Couldn't read managed process memory at {src:?} into packet payload at {dst:?}: {e:?}"
1250            );
1251        }
1252
1253        let dst = unsafe { dst.assume_init() };
1254
1255        log::trace!(
1256            "We read {} bytes from the managed process into the packet's payload",
1257            dst.len()
1258        );
1259
1260        tcp.borrow_mut().payload.push(Bytes::from(dst));
1262    }
1263
    /// Copies part of a legacy TCP packet's payload into managed-process memory.
    ///
    /// Starting `payload_offset` bytes into the payload (which may span several chunks), this
    /// writes as many bytes as fit into the `dst_len`-byte buffer at `dst`.
    ///
    /// Returns the number of bytes written, or `-EFAULT` if writing or flushing process memory
    /// fails. The count is 0 when `dst_len` is 0 or the offset is at or past the end of the
    /// payload.
    ///
    /// # Panics
    /// Panics if `packet_ptr` or `mem` is null, or if the packet is not a legacy TCP packet.
    #[unsafe(no_mangle)]
    pub extern "C-unwind" fn packet_copyPayloadWithMemoryManager(
        packet_ptr: *const Packet,
        payload_offset: u64,
        dst: UntypedForeignPtr,
        dst_len: u64,
        mem: *mut MemoryManager,
    ) -> i64 {
        let packet = PacketRc::borrow_raw(packet_ptr);
        // NOTE(review): assumes a non-null `mem` points to a live, exclusively borrowed
        // `MemoryManager` for this call; null is rejected by the `unwrap`.
        let mem = unsafe { mem.as_mut() }.unwrap();

        let Data::LegacyTcp(tcp) = &packet.data else {
            unimplemented!()
        };

        // Nothing fits in an empty destination buffer.
        if dst_len == 0 {
            return 0;
        }

        log::trace!(
            "Requested to write payload of len {} from offset {payload_offset} into managed \
            process buffer of len {dst_len}",
            packet.payload_len()
        );

        // Saturate rather than fail if `usize` is narrower than `u64`.
        let dst_len = usize::try_from(dst_len).unwrap_or(usize::MAX);
        let dst = ForeignArrayPtr::new(dst.cast::<u8>(), dst_len);

        let mut dst_writer = mem.writer(dst);
        // Room still left in `dst`.
        let mut dst_space = dst_len;
        // Bytes still to skip before copying starts (relative to the current chunk).
        let mut src_offset = usize::try_from(payload_offset).unwrap_or(usize::MAX);

        for bytes in &tcp.borrow().payload {
            // This whole chunk lies before the requested offset: skip it.
            if src_offset >= bytes.len() {
                src_offset = src_offset.saturating_sub(bytes.len());
                continue;
            }

            // Copy from the offset to the end of this chunk, limited by the room left in `dst`.
            let start = src_offset;
            let len = bytes.len().saturating_sub(start).min(dst_space);
            let end = start + len;
            assert!(start <= end);

            // Since `start < bytes.len()` here, `len` is 0 only once `dst` is full.
            if len == 0 {
                break;
            }

            log::trace!("Writing {len} bytes into managed process");

            if let Err(e) = dst_writer.write_all(&bytes[start..end]) {
                log::warn!(
                    "Couldn't write managed process memory at {dst:?} from packet payload: {e:?}"
                );
                return linux_api::errno::Errno::EFAULT.to_negated_i64();
            }

            dst_space = dst_space.saturating_sub(len);
            // Every later chunk is copied from its beginning.
            src_offset = 0;
        }

        let tot_written = dst_len.saturating_sub(dst_space);

        // Only flush when something was written. NOTE(review): the writer's buffering/flush
        // semantics are defined by `MemoryManager::writer` and are not visible here.
        if tot_written > 0
            && let Err(e) = dst_writer.flush()
        {
            log::warn!("Couldn't flush managed process writes from packet payload: {e:?}");
            return linux_api::errno::Errno::EFAULT.to_negated_i64();
        }

        log::trace!("We wrote {tot_written} bytes into managed process buffer of len {dst_len}");

        i64::try_from(tot_written).unwrap()
    }
1341
1342    #[unsafe(no_mangle)]
1343    pub extern "C-unwind" fn packet_getPriority(packet_ptr: *const Packet) -> u64 {
1344        let packet = PacketRc::borrow_raw(packet_ptr);
1345        packet.priority()
1346    }
1347
1348    #[unsafe(no_mangle)]
1349    pub extern "C-unwind" fn packet_getPayloadSize(packet_ptr: *const Packet) -> u64 {
1350        let packet = PacketRc::borrow_raw(packet_ptr);
1351        packet.payload_len().try_into().unwrap()
1352    }
1353
1354    #[unsafe(no_mangle)]
1355    pub extern "C-unwind" fn packet_getDestinationIP(packet_ptr: *const Packet) -> libc::in_addr_t {
1356        let packet = PacketRc::borrow_raw(packet_ptr);
1357        u32::to_be((*packet.dst_ipv4_address().ip()).into())
1358    }
1359
1360    #[unsafe(no_mangle)]
1361    pub extern "C-unwind" fn packet_getDestinationPort(
1362        packet_ptr: *const Packet,
1363    ) -> libc::in_port_t {
1364        let packet = PacketRc::borrow_raw(packet_ptr);
1365        u16::to_be(packet.dst_ipv4_address().port())
1366    }
1367
1368    #[unsafe(no_mangle)]
1369    pub extern "C-unwind" fn packet_getSourceIP(packet_ptr: *const Packet) -> libc::in_addr_t {
1370        let packet = PacketRc::borrow_raw(packet_ptr);
1371        u32::to_be((*packet.src_ipv4_address().ip()).into())
1372    }
1373
1374    #[unsafe(no_mangle)]
1375    pub extern "C-unwind" fn packet_getSourcePort(packet_ptr: *const Packet) -> libc::in_port_t {
1376        let packet = PacketRc::borrow_raw(packet_ptr);
1377        u16::to_be(packet.src_ipv4_address().port())
1378    }
1379
1380    #[unsafe(no_mangle)]
1381    pub extern "C-unwind" fn packet_addDeliveryStatus(
1382        packet_ptr: *mut Packet,
1383        status: c::PacketDeliveryStatusFlags,
1384    ) {
1385        let packet = PacketRc::borrow_raw_mut(packet_ptr);
1386        packet.add_status(PacketStatus::from(status));
1387    }
1388
1389    #[unsafe(no_mangle)]
1390    pub extern "C-unwind" fn packet_compareTCPSequence(
1391        packet_ptr1: *mut Packet,
1392        packet_ptr2: *mut Packet,
1393        _ptr: *mut libc::c_void,
1394    ) -> libc::c_int {
1395        let packet1 = PacketRc::borrow_raw_mut(packet_ptr1);
1396        let packet2 = PacketRc::borrow_raw_mut(packet_ptr2);
1397
1398        let seq1 = get_sequence_number(&packet1);
1399        let seq2 = get_sequence_number(&packet2);
1400
1401        match seq1.cmp(&seq2) {
1403            Ordering::Less => -1,
1404            Ordering::Equal => 0,
1405            Ordering::Greater => 1,
1406        }
1407    }
1408
1409    fn get_sequence_number(packet: &PacketRc) -> u32 {
1410        let Data::LegacyTcp(tcp_ref) = &packet.data else {
1411            unimplemented!()
1412        };
1413        tcp_ref.borrow().header.sequence
1414    }
1415
1416    fn legacy_flags_to_tcp_flags(legacy_flags: c::ProtocolTCPFlags) -> tcp::TcpFlags {
1417        let mut tcp_flags = tcp::TcpFlags::empty();
1421
1422        if legacy_flags & c::_ProtocolTCPFlags_PTCP_FIN != 0 {
1423            tcp_flags.insert(tcp::TcpFlags::FIN);
1424        }
1425        if legacy_flags & c::_ProtocolTCPFlags_PTCP_SYN != 0 {
1426            tcp_flags.insert(tcp::TcpFlags::SYN);
1427        }
1428        if legacy_flags & c::_ProtocolTCPFlags_PTCP_RST != 0 {
1429            tcp_flags.insert(tcp::TcpFlags::RST);
1430        }
1431        if legacy_flags & c::_ProtocolTCPFlags_PTCP_ACK != 0 {
1432            tcp_flags.insert(tcp::TcpFlags::ACK);
1433        }
1434        if legacy_flags & c::_ProtocolTCPFlags_PTCP_SACK != 0 {
1437            tcp_flags.insert(tcp::TcpFlags::ECE);
1438        }
1439        if legacy_flags & c::_ProtocolTCPFlags_PTCP_DUPACK != 0 {
1440            tcp_flags.insert(tcp::TcpFlags::CWR);
1441        }
1442
1443        tcp_flags
1444    }
1445
1446    fn tcp_flags_to_legacy_flags(tcp_flags: tcp::TcpFlags) -> c::ProtocolTCPFlags {
1447        let mut legacy_flags = c::_ProtocolTCPFlags_PTCP_NONE;
1451
1452        if tcp_flags.contains(tcp::TcpFlags::FIN) {
1453            legacy_flags |= c::_ProtocolTCPFlags_PTCP_FIN;
1454        }
1455        if tcp_flags.contains(tcp::TcpFlags::SYN) {
1456            legacy_flags |= c::_ProtocolTCPFlags_PTCP_SYN;
1457        }
1458        if tcp_flags.contains(tcp::TcpFlags::RST) {
1459            legacy_flags |= c::_ProtocolTCPFlags_PTCP_RST;
1460        }
1461        if tcp_flags.contains(tcp::TcpFlags::ACK) {
1462            legacy_flags |= c::_ProtocolTCPFlags_PTCP_ACK;
1463        }
1464        if tcp_flags.contains(tcp::TcpFlags::ECE) {
1466            legacy_flags |= c::_ProtocolTCPFlags_PTCP_SACK;
1467        }
1468        if tcp_flags.contains(tcp::TcpFlags::CWR) {
1469            legacy_flags |= c::_ProtocolTCPFlags_PTCP_DUPACK;
1470        }
1471
1472        legacy_flags
1473    }
1474
1475    fn from_legacy_timestamp(ts: c::CSimulationTime) -> Option<u32> {
1476        SimulationTime::from_c_simtime(ts).map(|x| u32::try_from(x.as_millis()).unwrap_or(u32::MAX))
1477    }
1478
1479    fn to_legacy_timestamp(val: Option<u32>) -> c::CSimulationTime {
1480        SimulationTime::to_c_simtime(val.map(|x| SimulationTime::from_millis(x as u64)))
1481    }
1482
1483    impl From<c::PacketSelectiveAcks> for TcpSelectiveAcks {
1484        fn from(c_sel_acks: c::PacketSelectiveAcks) -> Self {
1485            let mut selective_acks = TcpSelectiveAcks::default();
1486
1487            assert!(c_sel_acks.len <= 4);
1488
1489            for i in 0..(c_sel_acks.len as usize) {
1490                let start: u32 = c_sel_acks.ranges[i].start;
1491                let end: u32 = c_sel_acks.ranges[i].end;
1492                selective_acks.ranges[i] = (start, end);
1493                selective_acks.len += 1;
1494            }
1495
1496            selective_acks
1497        }
1498    }
1499
1500    fn to_legacy_sel_acks(selective_acks: Option<TcpSelectiveAcks>) -> c::PacketSelectiveAcks {
1501        let mut c_sel_acks: c::PacketSelectiveAcks = unsafe { MaybeUninit::zeroed().assume_init() };
1502
1503        let Some(selective_acks) = selective_acks else {
1504            return c_sel_acks;
1505        };
1506
1507        assert!(selective_acks.len <= 4);
1508
1509        for i in 0..(selective_acks.len as usize) {
1510            let (start, end) = selective_acks.ranges[i];
1511            c_sel_acks.ranges[i].start = start;
1512            c_sel_acks.ranges[i].end = end;
1513            c_sel_acks.len += 1;
1514        }
1515
1516        c_sel_acks
1517    }
1518
1519    impl From<c::ProtocolType> for IanaProtocol {
1520        fn from(value: c::ProtocolType) -> Self {
1521            match value {
1522                c::_ProtocolType_PTCP => IanaProtocol::Tcp,
1523                c::_ProtocolType_PUDP => IanaProtocol::Udp,
1524                _ => panic!("Unexpected protocol type {value}"),
1525            }
1526        }
1527    }
1528
1529    impl From<c::PacketDeliveryStatusFlags> for PacketStatus {
1530        fn from(legacy_status: c::PacketDeliveryStatusFlags) -> Self {
1531            match legacy_status {
1532                c::_PacketDeliveryStatusFlags_PDS_SND_CREATED => PacketStatus::SndCreated,
1533                c::_PacketDeliveryStatusFlags_PDS_SND_TCP_ENQUEUE_THROTTLED => {
1534                    PacketStatus::SndTcpEnqueueThrottled
1535                }
1536                c::_PacketDeliveryStatusFlags_PDS_SND_TCP_ENQUEUE_RETRANSMIT => {
1537                    PacketStatus::SndTcpEnqueueRetransmit
1538                }
1539                c::_PacketDeliveryStatusFlags_PDS_SND_TCP_DEQUEUE_RETRANSMIT => {
1540                    PacketStatus::SndTcpDequeueRetransmit
1541                }
1542                c::_PacketDeliveryStatusFlags_PDS_SND_TCP_RETRANSMITTED => {
1543                    PacketStatus::SndTcpRetransmitted
1544                }
1545                c::_PacketDeliveryStatusFlags_PDS_SND_SOCKET_BUFFERED => {
1546                    PacketStatus::SndSocketBuffered
1547                }
1548                c::_PacketDeliveryStatusFlags_PDS_SND_INTERFACE_SENT => {
1549                    PacketStatus::SndInterfaceSent
1550                }
1551                c::_PacketDeliveryStatusFlags_PDS_INET_SENT => PacketStatus::InetSent,
1552                c::_PacketDeliveryStatusFlags_PDS_INET_DROPPED => PacketStatus::InetDropped,
1553                c::_PacketDeliveryStatusFlags_PDS_ROUTER_ENQUEUED => PacketStatus::RouterEnqueued,
1554                c::_PacketDeliveryStatusFlags_PDS_ROUTER_DEQUEUED => PacketStatus::RouterDequeued,
1555                c::_PacketDeliveryStatusFlags_PDS_ROUTER_DROPPED => PacketStatus::RouterDropped,
1556                c::_PacketDeliveryStatusFlags_PDS_RCV_INTERFACE_RECEIVED => {
1557                    PacketStatus::RcvInterfaceReceived
1558                }
1559                c::_PacketDeliveryStatusFlags_PDS_RCV_INTERFACE_DROPPED => {
1560                    PacketStatus::RcvInterfaceDropped
1561                }
1562                c::_PacketDeliveryStatusFlags_PDS_RCV_SOCKET_PROCESSED => {
1563                    PacketStatus::RcvSocketProcessed
1564                }
1565                c::_PacketDeliveryStatusFlags_PDS_RCV_SOCKET_DROPPED => {
1566                    PacketStatus::RcvSocketDropped
1567                }
1568                c::_PacketDeliveryStatusFlags_PDS_RCV_TCP_ENQUEUE_UNORDERED => {
1569                    PacketStatus::RcvTcpEnqueueUnordered
1570                }
1571                c::_PacketDeliveryStatusFlags_PDS_RCV_SOCKET_BUFFERED => {
1572                    PacketStatus::RcvSocketBuffered
1573                }
1574                c::_PacketDeliveryStatusFlags_PDS_RCV_SOCKET_DELIVERED => {
1575                    PacketStatus::RcvSocketDelivered
1576                }
1577                c::_PacketDeliveryStatusFlags_PDS_DESTROYED => PacketStatus::Destroyed,
1578                c::_PacketDeliveryStatusFlags_PDS_RELAY_CACHED => PacketStatus::RelayCached,
1579                c::_PacketDeliveryStatusFlags_PDS_RELAY_FORWARDED => PacketStatus::RelayForwarded,
1580                _ => unimplemented!(),
1581            }
1582        }
1583    }
1584}