shadow_rs/utility/
pcap_writer.rs

1use std::io::{Seek, SeekFrom, Write};
2
3use crate::utility::give::Give;
4
/// Writes packets to an underlying writer as a pcap capture file.
///
/// All multi-byte header fields are written in the host's native byte order; readers use the
/// magic number at the start of the file to detect that order. Packet timestamps are expressed
/// as seconds plus microseconds.
pub struct PcapWriter<W: Write> {
    /// Destination for the capture file bytes.
    writer: W,
    /// Maximum number of bytes of each packet that will be stored (the snapshot length).
    capture_len: u32,
}

impl<W: Write> PcapWriter<W> {
    /// A new packet capture writer. Each packet (header and payload) captured will be truncated to
    /// a length `capture_len`.
    ///
    /// The pcap file header is written to `writer` immediately.
    ///
    /// # Errors
    ///
    /// Returns any I/O error raised while writing the file header.
    pub fn new(writer: W, capture_len: u32) -> std::io::Result<Self> {
        let mut rv = PcapWriter {
            writer,
            capture_len,
        };

        rv.write_header()?;

        Ok(rv)
    }

    /// Write the 24-byte pcap file header.
    ///
    /// The header is assembled in memory and handed to the writer in a single `write_all` call so
    /// that an unbuffered writer is not hit once per field.
    fn write_header(&mut self) -> std::io::Result<()> {
        // magic number to show endianness
        const MAGIC_NUMBER: u32 = 0xA1B2C3D4;
        const VERSION_MAJOR: u16 = 2;
        const VERSION_MINOR: u16 = 4;
        // GMT to local correction
        const THIS_ZONE: i32 = 0;
        // accuracy of timestamps
        const SIG_FLAGS: u32 = 0;
        // data link type (LINKTYPE_RAW)
        const NETWORK: u32 = 101;

        let mut header = [0u8; 24];

        // magic number: 4 bytes
        header[0..4].copy_from_slice(&MAGIC_NUMBER.to_ne_bytes());
        // major version: 2 bytes
        header[4..6].copy_from_slice(&VERSION_MAJOR.to_ne_bytes());
        // minor version: 2 bytes
        header[6..8].copy_from_slice(&VERSION_MINOR.to_ne_bytes());
        // GMT to local correction: 4 bytes
        header[8..12].copy_from_slice(&THIS_ZONE.to_ne_bytes());
        // accuracy of timestamps: 4 bytes
        header[12..16].copy_from_slice(&SIG_FLAGS.to_ne_bytes());
        // snapshot length: 4 bytes
        header[16..20].copy_from_slice(&self.capture_len.to_ne_bytes());
        // link type: 4 bytes
        header[20..24].copy_from_slice(&NETWORK.to_ne_bytes());

        self.writer.write_all(&header)
    }

    /// Write a packet from a buffer.
    ///
    /// At most `capture_len` bytes of `packet` are stored; the record header still reports the
    /// full length of `packet` as the original packet length. `ts_sec` and `ts_usec` are the
    /// seconds and microseconds parts of the packet's timestamp.
    ///
    /// # Errors
    ///
    /// Returns an [`std::io::ErrorKind::InvalidInput`] error, without writing anything, if the
    /// length of `packet` does not fit in the record's 32-bit length field. Otherwise returns any
    /// I/O error raised by the underlying writer.
    pub fn write_packet(
        &mut self,
        ts_sec: u32,
        ts_usec: u32,
        packet: &[u8],
    ) -> std::io::Result<()> {
        let packet_len = u32::try_from(packet.len()).map_err(|_| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "packet length does not fit in a 32-bit pcap record length",
            )
        })?;
        let packet_trunc_len = packet_len.min(self.capture_len);
        let captured_end = usize::try_from(packet_trunc_len)
            .expect("captured length never exceeds `packet.len()`, which is a usize");

        // the record header is built up front and written with a single call
        let mut record_header = [0u8; 16];

        // timestamp (seconds): 4 bytes
        record_header[0..4].copy_from_slice(&ts_sec.to_ne_bytes());
        // timestamp (microseconds): 4 bytes
        record_header[4..8].copy_from_slice(&ts_usec.to_ne_bytes());
        // captured packet length: 4 bytes
        record_header[8..12].copy_from_slice(&packet_trunc_len.to_ne_bytes());
        // original packet length: 4 bytes
        record_header[12..16].copy_from_slice(&packet_len.to_ne_bytes());

        self.writer.write_all(&record_header)?;

        // packet data: `packet_trunc_len` bytes
        self.writer.write_all(&packet[..captured_end])
    }
}
82
83impl<W: Write + Seek> PcapWriter<W> {
84    /// Write a packet without requiring an intermediate buffer.
85    pub fn write_packet_fmt(
86        &mut self,
87        ts_sec: u32,
88        ts_usec: u32,
89        packet_len: u32,
90        write_packet_fn: impl FnOnce(&mut Give<&mut W>) -> std::io::Result<()>,
91    ) -> std::io::Result<()> {
92        // timestamp (seconds): 4 bytes
93        self.writer.write_all(&ts_sec.to_ne_bytes())?;
94        // timestamp (microseconds): 4 bytes
95        self.writer.write_all(&ts_usec.to_ne_bytes())?;
96
97        // position of the captured packet length field
98        let pos_of_len = self.writer.stream_position()?;
99
100        // captured packet length: 4 bytes
101        // (write initially as 0, we'll update it later)
102        self.writer.write_all(&0u32.to_ne_bytes())?;
103        // original packet length: 4 bytes
104        self.writer.write_all(&packet_len.to_ne_bytes())?;
105
106        // position of the packet data
107        let pos_before_packet_data = self.writer.stream_position()?;
108
109        // packet data: a soft limit of `capture_len` bytes
110        match write_packet_fn(&mut Give::new(&mut self.writer, self.capture_len as u64)) {
111            Ok(()) => {}
112            // this should mean that the entire packet couldn't be written, which is fine since
113            // we'll use a smaller captured packet length value
114            Err(e) if e.kind() == std::io::ErrorKind::WriteZero => {}
115            Err(e) => return Err(e),
116        }
117
118        // position after the packet data
119        let pos_after_packet_data = self.writer.stream_position()?;
120        // the number of packet data bytes written
121        let bytes_written = pos_after_packet_data - pos_before_packet_data;
122
123        // it is still possible for 'write_payload_fn' to have written more bytes than it was
124        // supposed to, so double check here
125        if bytes_written > self.capture_len.into() {
126            log::warn!(
127                "Pcap writer wrote more bytes than intended: {bytes_written} > {}",
128                self.capture_len
129            );
130            return Err(std::io::ErrorKind::InvalidData.into());
131        }
132
133        // go back and update the captured packet length
134        let bytes_written = u32::try_from(bytes_written).unwrap();
135        self.writer.seek(SeekFrom::Start(pos_of_len))?;
136        // captured packet length: 4 bytes
137        self.writer.write_all(&bytes_written.to_ne_bytes())?;
138        self.writer.seek(SeekFrom::Start(pos_after_packet_data))?;
139
140        Ok(())
141    }
142}
143
/// Implemented by packet types that can emit their bytes to a writer.
pub trait PacketDisplay {
    /// Write the bytes of this packet to `writer`.
    ///
    /// # Errors
    ///
    /// Implementations report failures as a [`std::io::Error`].
    fn display_bytes(&self, writer: impl Write) -> std::io::Result<()>;
}
148
#[cfg(test)]
mod tests {
    use std::io::Cursor;

    use super::*;

    /// Snapshot length used by the tests that do not exercise truncation.
    const CAPTURE_LEN: u32 = 65535;

    /// The file header expected for the given snapshot length.
    ///
    /// The writer emits every field in host byte order, so the expectation is built the same way
    /// rather than from hard-coded little-endian bytes.
    fn expected_file_header(capture_len: u32) -> Vec<u8> {
        let mut bytes = Vec::with_capacity(24);
        // magic number
        bytes.extend_from_slice(&0xA1B2C3D4u32.to_ne_bytes());
        // version 2.4
        bytes.extend_from_slice(&2u16.to_ne_bytes());
        bytes.extend_from_slice(&4u16.to_ne_bytes());
        // GMT to local correction
        bytes.extend_from_slice(&0i32.to_ne_bytes());
        // accuracy of timestamps
        bytes.extend_from_slice(&0u32.to_ne_bytes());
        // snapshot length
        bytes.extend_from_slice(&capture_len.to_ne_bytes());
        // link type (LINKTYPE_RAW)
        bytes.extend_from_slice(&101u32.to_ne_bytes());
        bytes
    }

    /// The per-packet record header expected for the given field values, in host byte order.
    fn expected_record_header(
        ts_sec: u32,
        ts_usec: u32,
        captured_len: u32,
        original_len: u32,
    ) -> Vec<u8> {
        let mut bytes = Vec::with_capacity(16);
        bytes.extend_from_slice(&ts_sec.to_ne_bytes());
        bytes.extend_from_slice(&ts_usec.to_ne_bytes());
        bytes.extend_from_slice(&captured_len.to_ne_bytes());
        bytes.extend_from_slice(&original_len.to_ne_bytes());
        bytes
    }

    #[test]
    fn test_empty_pcap_writer() {
        let mut buf = vec![];
        PcapWriter::new(&mut buf, CAPTURE_LEN).unwrap();

        assert_eq!(buf, expected_file_header(CAPTURE_LEN));
    }

    /// Pins the exact on-disk bytes of the file header on little-endian hosts.
    #[cfg(target_endian = "little")]
    #[test]
    fn test_file_header_little_endian_layout() {
        let mut buf = vec![];
        PcapWriter::new(&mut buf, CAPTURE_LEN).unwrap();

        let expected_header = [
            0xD4, 0xC3, 0xB2, 0xA1, 0x02, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            0x00, 0x00, 0xFF, 0xFF, 0x00, 0x00, 0x65, 0x00, 0x00, 0x00,
        ];

        assert_eq!(buf, expected_header);
    }

    #[test]
    fn test_write_packet() {
        let mut buf = vec![];
        let mut pcap = PcapWriter::new(&mut buf, CAPTURE_LEN).unwrap();
        pcap.write_packet(32, 128, &[0x01, 0x02, 0x03]).unwrap();

        let expected = [
            expected_file_header(CAPTURE_LEN),
            expected_record_header(32, 128, 3, 3),
            vec![0x01, 0x02, 0x03],
        ]
        .concat();

        assert_eq!(buf, expected);
    }

    #[test]
    fn test_write_packet_truncated() {
        let mut buf = vec![];
        let mut pcap = PcapWriter::new(&mut buf, 2).unwrap();
        pcap.write_packet(32, 128, &[0x01, 0x02, 0x03]).unwrap();

        // only 2 of the 3 bytes are captured, but the original length is still recorded as 3
        let expected = [
            expected_file_header(2),
            expected_record_header(32, 128, 2, 3),
            vec![0x01, 0x02],
        ]
        .concat();

        assert_eq!(buf, expected);
    }

    #[test]
    fn test_write_packet_fmt() {
        let mut buf = Cursor::new(vec![]);
        let mut pcap = PcapWriter::new(&mut buf, CAPTURE_LEN).unwrap();
        pcap.write_packet_fmt(32, 128, 3, |writer| {
            writer.write_all(&[0x01])?;
            writer.write_all(&[0x02])?;
            writer.write_all(&[0x03])
        })
        .unwrap();

        let expected = [
            expected_file_header(CAPTURE_LEN),
            expected_record_header(32, 128, 3, 3),
            vec![0x01, 0x02, 0x03],
        ]
        .concat();

        let buf = buf.into_inner();

        assert_eq!(buf, expected);
    }
}