neli/
nl.rs

1//! This module contains the top level netlink header code. Every
2//! netlink message will be encapsulated in a top level `Nlmsghdr`.
3//!
4//! [`Nlmsghdr`][crate::nl::Nlmsghdr] is the structure representing a
5//! header that all netlink protocols require to be passed to the
6//! correct destination.
7//!
8//! # Design decisions
9//!
10//! Payloads for [`Nlmsghdr`][crate::nl::Nlmsghdr] can be any type.
11//!
12//! The payload is wrapped in an enum to facilitate better
13//! application-level error handling.
14
15use crate as neli;
16
17use std::{
18    any::type_name,
19    io::{Cursor, Read},
20};
21
22use log::trace;
23
24use crate::{
25    consts::nl::{NlType, NlmFFlags, Nlmsg},
26    err::{DeError, NlError, Nlmsgerr, NlmsghdrErr},
27    FromBytes, FromBytesWithInput, Header, Size, ToBytes, TypeSize,
28};
29
/// An enum representing either the desired payload as requested
/// by the payload type parameter, an ACK received at the end
/// of a message or stream of messages, or an error.
///
/// When deserializing, the variant is selected from the message type
/// of the enclosing header: `Nlmsg::Done` produces [`NlPayload::Empty`],
/// `Nlmsg::Error` produces [`NlPayload::Ack`] (error code `0`) or
/// [`NlPayload::Err`] (non-zero error code), and every other type is
/// parsed as `P` into [`NlPayload::Payload`].
#[derive(Debug, PartialEq, Eq, Size, ToBytes)]
pub enum NlPayload<T, P> {
    /// Represents an ACK returned by netlink: an error message whose
    /// error code is `0`. The echoed header carries no payload (`()`).
    Ack(Nlmsgerr<T, ()>),
    /// Represents an application level error returned by netlink: an
    /// error message with a non-zero error code.
    Err(Nlmsgerr<T, P>),
    /// Represents the requested payload.
    Payload(P),
    /// Indicates an empty payload (produced for `Nlmsg::Done` messages).
    Empty,
}
44
45impl<T, P> NlPayload<T, P> {
46    /// Get the payload of the netlink packet and return [`None`]
47    /// if the contained data in the payload is actually an ACK
48    /// or an error.
49    pub fn get_payload(&self) -> Option<&P> {
50        match self {
51            NlPayload::Payload(ref p) => Some(p),
52            _ => None,
53        }
54    }
55}
56
57impl<'a, T, P> FromBytesWithInput<'a> for NlPayload<T, P>
58where
59    P: FromBytesWithInput<'a, Input = usize>,
60    T: NlType,
61{
62    type Input = (usize, T);
63
64    fn from_bytes_with_input(
65        buffer: &mut Cursor<&'a [u8]>,
66        (input_size, input_type): (usize, T),
67    ) -> Result<Self, DeError> {
68        trace!("Deserializing data type {}", type_name::<Self>());
69        let ty_const: u16 = input_type.into();
70        if ty_const == Nlmsg::Done.into() {
71            trace!("Received empty payload");
72            let mut bytes = Vec::new();
73            buffer.read_to_end(&mut bytes)?;
74            trace!("Padding: {:?}", bytes);
75            Ok(NlPayload::Empty)
76        } else if ty_const == Nlmsg::Error.into() {
77            trace!(
78                "Deserializing field type {}",
79                std::any::type_name::<libc::c_int>()
80            );
81            let code = libc::c_int::from_bytes(buffer)?;
82            trace!("Field deserialized: {:?}", code);
83            if code == 0 {
84                Ok(NlPayload::Ack(Nlmsgerr {
85                    error: code,
86                    nlmsg: {
87                        trace!(
88                            "Deserializing field type {}",
89                            std::any::type_name::<NlmsghdrErr<T, ()>>()
90                        );
91                        trace!("Input: {:?}", input_size);
92                        let ok = NlmsghdrErr::<T, ()>::from_bytes_with_input(
93                            buffer,
94                            input_size - libc::c_int::type_size(),
95                        )?;
96                        trace!("Field deserialized: {:?}", ok);
97                        ok
98                    },
99                }))
100            } else {
101                Ok(NlPayload::Err(Nlmsgerr {
102                    error: code,
103                    nlmsg: {
104                        trace!(
105                            "Deserializing field type {}",
106                            std::any::type_name::<NlmsghdrErr<T, ()>>()
107                        );
108                        trace!("Input: {:?}", input_size);
109                        let ok = NlmsghdrErr::<T, P>::from_bytes_with_input(
110                            buffer,
111                            input_size - libc::c_int::type_size(),
112                        )?;
113                        trace!("Field deserialized: {:?}", ok);
114                        ok
115                    },
116                }))
117            }
118        } else {
119            Ok(NlPayload::Payload(P::from_bytes_with_input(
120                buffer, input_size,
121            )?))
122        }
123    }
124}
125
/// Top level netlink header and payload.
///
/// Every netlink message is wrapped in this header. `nl_len` counts the
/// whole message, header included; the payload is deserialized from the
/// `nl_len - header_size` bytes that follow the header fields.
#[derive(Debug, PartialEq, Eq, Size, ToBytes, FromBytes, Header)]
#[neli(header_bound = "T: TypeSize")]
#[neli(from_bytes_bound = "T: NlType")]
#[neli(from_bytes_bound = "P: FromBytesWithInput<Input = usize>")]
// NOTE(review): presumably tells the derive to account for alignment
// padding after the message — confirm against the neli derive docs.
#[neli(padding)]
pub struct Nlmsghdr<T, P> {
    /// Length of the netlink message, including this header
    pub nl_len: u32,
    /// Type of the netlink message; also selects how the payload is parsed
    pub nl_type: T,
    /// Flags indicating properties of the request or response
    pub nl_flags: NlmFFlags,
    /// Sequence number for netlink protocol
    pub nl_seq: u32,
    /// ID of the netlink destination for requests and source for
    /// responses.
    pub nl_pid: u32,
    /// Payload of netlink message
    // The payload receives the byte count remaining after the header and
    // the message type so it can choose the right variant.
    // NOTE(review): `nl_len as usize - Self::header_size() as usize` is an
    // unchecked subtraction; a received `nl_len` smaller than the header
    // would underflow (panic in debug builds). Consider a checked form
    // once the derive macro's handling of these expressions is confirmed.
    #[neli(input = "(nl_len as usize - Self::header_size() as usize, nl_type)")]
    #[neli(size = "nl_len as usize - Self::header_size() as usize")]
    pub nl_payload: NlPayload<T, P>,
}
149
150impl<T, P> Nlmsghdr<T, P>
151where
152    T: NlType,
153    P: Size,
154{
155    /// Create a new top level netlink packet with a payload.
156    pub fn new(
157        nl_len: Option<u32>,
158        nl_type: T,
159        nl_flags: NlmFFlags,
160        nl_seq: Option<u32>,
161        nl_pid: Option<u32>,
162        nl_payload: NlPayload<T, P>,
163    ) -> Self {
164        let mut nl = Nlmsghdr {
165            nl_len: 0,
166            nl_type,
167            nl_flags,
168            nl_seq: nl_seq.unwrap_or(0),
169            nl_pid: nl_pid.unwrap_or(0),
170            nl_payload,
171        };
172        nl.nl_len = nl_len.unwrap_or(nl.padded_size() as u32);
173        nl
174    }
175
176    /// Get the payload if there is one or return an error.
177    pub fn get_payload(&self) -> Result<&P, NlError> {
178        match self.nl_payload {
179            NlPayload::Payload(ref p) => Ok(p),
180            _ => Err(NlError::new("This packet does not have a payload")),
181        }
182    }
183}