neli/
nl.rs

1//! This module contains the top level netlink header code. Every
2//! netlink message will be encapsulated in a top level `Nlmsghdr`.
3//!
4//! [`Nlmsghdr`] is the structure representing a
5//! header that all netlink protocols require to be passed to the
6//! correct destination.
7//!
8//! # Design decisions
9//!
10//! Payloads for [`Nlmsghdr`] can be any type.
11//!
12//! The payload is wrapped in an enum to facilitate better
13//! application-level error handling.
14
15use std::{
16    io::Cursor,
17    mem::{size_of, swap},
18};
19
20use derive_builder::{Builder, UninitializedFieldError};
21use getset::Getters;
22use log::trace;
23
24use crate::{
25    self as neli, FromBytes, FromBytesWithInput, Header, Size, ToBytes, TypeSize,
26    consts::nl::{NlType, NlmF, Nlmsg},
27    err::{DeError, Nlmsgerr, NlmsgerrBuilder, NlmsghdrAck, NlmsghdrErr, RouterError},
28    types::{Buffer, GenlBuffer},
29};
30
31/// An enum representing either the desired payload as requested
32/// by the payload type parameter, an ACK received at the end
33/// of a message or stream of messages, or an error.
34#[derive(Clone, Debug, PartialEq, Eq, Size, ToBytes)]
35pub enum NlPayload<T, P> {
36    /// Represents an ACK returned by netlink.
37    Ack(Nlmsgerr<NlmsghdrAck<T>>),
38    /// Represents an ACK extracted from the DONE packet returned by netlink
39    /// on a DUMP.
40    DumpExtAck(Nlmsgerr<()>),
41    /// Represents an application level error returned by netlink.
42    Err(Nlmsgerr<NlmsghdrErr<T, P>>),
43    /// Represents the requested payload.
44    Payload(P),
45    /// Indicates an empty payload.
46    Empty,
47}
48
49impl<T, P> FromBytesWithInput for NlPayload<T, P>
50where
51    P: Size + FromBytesWithInput<Input = usize>,
52    T: NlType,
53{
54    type Input = (usize, T, NlmF);
55
56    fn from_bytes_with_input(
57        buffer: &mut Cursor<impl AsRef<[u8]>>,
58        (input_size, input_type, flags): (usize, T, NlmF),
59    ) -> Result<Self, DeError> {
60        let pos = buffer.position();
61
62        let mut processing = || {
63            trace!("Deserializing data type {}", std::any::type_name::<Self>());
64            let ty_const: u16 = input_type.into();
65            if ty_const == Nlmsg::Done.into() {
66                if buffer.position() == buffer.get_ref().as_ref().len() as u64 {
67                    Ok(NlPayload::Empty)
68                } else if flags.contains(NlmF::MULTI) {
69                    trace!(
70                        "Deserializing field type {}",
71                        std::any::type_name::<Nlmsgerr<()>>(),
72                    );
73                    trace!("Input: {input_size:?}");
74                    let ext = Nlmsgerr::from_bytes_with_input(buffer, input_size)?;
75                    Ok(NlPayload::DumpExtAck(ext))
76                } else {
77                    // This is specifically targeting the connector protocol.
78                    // As more protocols are added, this may need to be changed.
79                    Ok(NlPayload::Payload(P::from_bytes_with_input(
80                        buffer, input_size,
81                    )?))
82                }
83            } else if ty_const == Nlmsg::Error.into() {
84                trace!(
85                    "Deserializing field type {}",
86                    std::any::type_name::<libc::c_int>()
87                );
88                let code = libc::c_int::from_bytes(buffer)?;
89                trace!("Field deserialized: {code:?}");
90                if code == 0 {
91                    trace!(
92                        "Deserializing field type {}",
93                        std::any::type_name::<NlmsghdrErr<T, ()>>()
94                    );
95                    trace!("Input: {input_size:?}");
96                    let nlmsg = NlmsghdrAck::<T>::from_bytes(buffer)?;
97                    trace!("Field deserialized: {nlmsg:?}");
98                    Ok(NlPayload::Ack(
99                        NlmsgerrBuilder::default().nlmsg(nlmsg).build()?,
100                    ))
101                } else {
102                    trace!(
103                        "Deserializing field type {}",
104                        std::any::type_name::<NlmsghdrErr<T, ()>>()
105                    );
106                    let nlmsg = NlmsghdrErr::<T, P>::from_bytes(buffer)?;
107                    trace!("Field deserialized: {nlmsg:?}");
108
109                    trace!(
110                        "Deserializing field type {}",
111                        std::any::type_name::<GenlBuffer<u16, Buffer>>()
112                    );
113                    let input = input_size - size_of::<libc::c_int>() - nlmsg.padded_size();
114                    trace!("Input: {input:?}");
115                    let ext_ack = GenlBuffer::from_bytes_with_input(buffer, input)?;
116                    trace!("Field deserialized: {ext_ack:?}");
117
118                    Ok(NlPayload::Err(
119                        NlmsgerrBuilder::default()
120                            .error(code)
121                            .nlmsg(nlmsg)
122                            .ext_ack(ext_ack)
123                            .build()?,
124                    ))
125                }
126            } else {
127                Ok(NlPayload::Payload(P::from_bytes_with_input(
128                    buffer, input_size,
129                )?))
130            }
131        };
132
133        match processing() {
134            Ok(o) => Ok(o),
135            Err(e) => {
136                buffer.set_position(pos);
137                Err(e)
138            }
139        }
140    }
141}
142
143/// Top level netlink header and payload
144#[derive(Builder, Getters, Clone, Debug, PartialEq, Eq, Size, ToBytes, FromBytes, Header)]
145#[neli(header_bound = "T: TypeSize")]
146#[neli(from_bytes_bound = "T: NlType")]
147#[neli(from_bytes_bound = "P: Size + FromBytesWithInput<Input = usize>")]
148#[neli(padding)]
149#[builder(build_fn(skip))]
150#[builder(pattern = "owned")]
151pub struct Nlmsghdr<T, P> {
152    /// Length of the netlink message
153    #[builder(setter(skip))]
154    #[getset(get = "pub")]
155    nl_len: u32,
156    /// Type of the netlink message
157    #[getset(get = "pub")]
158    nl_type: T,
159    /// Flags indicating properties of the request or response
160    #[getset(get = "pub")]
161    nl_flags: NlmF,
162    /// Sequence number for netlink protocol
163    #[getset(get = "pub")]
164    nl_seq: u32,
165    /// ID of the netlink destination for requests and source for
166    /// responses.
167    #[getset(get = "pub")]
168    nl_pid: u32,
169    /// Payload of netlink message
170    #[neli(input = "(nl_len as usize - Self::header_size() as usize, nl_type, nl_flags)")]
171    #[neli(size = "nl_len as usize - Self::header_size() as usize")]
172    #[getset(get = "pub")]
173    pub(crate) nl_payload: NlPayload<T, P>,
174}
175
176impl<T, P> NlmsghdrBuilder<T, P>
177where
178    T: NlType,
179    P: Size,
180{
181    /// Build [`Nlmsghdr`].
182    pub fn build(self) -> Result<Nlmsghdr<T, P>, NlmsghdrBuilderError> {
183        let nl_type = self
184            .nl_type
185            .ok_or_else(|| NlmsghdrBuilderError::from(UninitializedFieldError::new("nl_type")))?;
186        let nl_flags = self
187            .nl_flags
188            .ok_or_else(|| NlmsghdrBuilderError::from(UninitializedFieldError::new("nl_flags")))?;
189        let nl_seq = self.nl_seq.unwrap_or(0);
190        let nl_pid = self.nl_pid.unwrap_or(0);
191        let nl_payload = self.nl_payload.ok_or_else(|| {
192            NlmsghdrBuilderError::from(UninitializedFieldError::new("nl_payload"))
193        })?;
194
195        let mut nl = Nlmsghdr {
196            nl_len: 0,
197            nl_type,
198            nl_flags,
199            nl_seq,
200            nl_pid,
201            nl_payload,
202        };
203        nl.nl_len = nl.padded_size() as u32;
204        Ok(nl)
205    }
206}
207
208impl<T, P> Nlmsghdr<T, P>
209where
210    T: NlType,
211{
212    /// Get the payload if there is one or return an error.
213    pub fn get_payload(&self) -> Option<&P> {
214        match self.nl_payload {
215            NlPayload::Payload(ref p) => Some(p),
216            _ => None,
217        }
218    }
219
220    /// Get an error from the payload if it exists.
221    ///
222    /// Takes a mutable reference because the payload will be swapped for
223    /// [`Empty`][NlPayload::Empty] to gain ownership of the error.
224    pub fn get_err(&mut self) -> Option<Nlmsgerr<NlmsghdrErr<T, P>>> {
225        match self.nl_payload {
226            NlPayload::Err(_) => {
227                let mut payload = NlPayload::Empty;
228                swap(&mut self.nl_payload, &mut payload);
229                match payload {
230                    NlPayload::Err(e) => Some(e),
231                    _ => unreachable!(),
232                }
233            }
234            _ => None,
235        }
236    }
237}
238
239impl NlPayload<u16, Buffer> {
240    /// Convert a typed payload from a payload that can represent all types.
241    pub fn to_typed<T, P>(self, payload_size: usize) -> Result<NlPayload<T, P>, RouterError<T, P>>
242    where
243        T: NlType,
244        P: Size + FromBytesWithInput<Input = usize>,
245    {
246        match self {
247            NlPayload::Ack(a) => Ok(NlPayload::Ack(a.to_typed()?)),
248            NlPayload::Err(e) => Ok(NlPayload::Err(e.to_typed()?)),
249            NlPayload::DumpExtAck(a) => Ok(NlPayload::DumpExtAck(a)),
250            NlPayload::Payload(p) => Ok(NlPayload::Payload(P::from_bytes_with_input(
251                &mut Cursor::new(p),
252                payload_size,
253            )?)),
254            NlPayload::Empty => Ok(NlPayload::Empty),
255        }
256    }
257}
258
259impl<T, P> Nlmsghdr<T, P>
260where
261    T: NlType,
262    P: Size,
263{
264    /// Set the payload for [`Nlmsghdr`] and handle the change in length internally.
265    pub fn set_payload(&mut self, p: NlPayload<T, P>) {
266        self.nl_len -= self.nl_payload.padded_size() as u32;
267        self.nl_len += p.padded_size() as u32;
268        self.nl_payload = p;
269    }
270}
271
272impl Nlmsghdr<u16, Buffer> {
273    /// Set the payload for [`Nlmsghdr`] and handle the change in length internally.
274    pub fn to_typed<T, P>(self) -> Result<Nlmsghdr<T, P>, RouterError<T, P>>
275    where
276        T: NlType,
277        P: Size + FromBytesWithInput<Input = usize>,
278    {
279        Ok(NlmsghdrBuilder::default()
280            .nl_type(T::from(self.nl_type))
281            .nl_flags(self.nl_flags)
282            .nl_seq(self.nl_seq)
283            .nl_pid(self.nl_pid)
284            .nl_payload(
285                self.nl_payload
286                    .to_typed::<T, P>(self.nl_len as usize - Self::header_size())?,
287            )
288            .build()?)
289    }
290}