1use std::{
16 io::Cursor,
17 mem::{size_of, swap},
18};
19
20use derive_builder::{Builder, UninitializedFieldError};
21use getset::Getters;
22use log::trace;
23
24use crate::{
25 self as neli, FromBytes, FromBytesWithInput, Header, Size, ToBytes, TypeSize,
26 consts::nl::{NlType, NlmF, Nlmsg},
27 err::{DeError, Nlmsgerr, NlmsgerrBuilder, NlmsghdrAck, NlmsghdrErr, RouterError},
28 types::{Buffer, GenlBuffer},
29};
30
/// Payload of a netlink message: either the application payload `P` or one of
/// the control forms (ACK, error, dump completion) that can appear in its place.
#[derive(Clone, Debug, PartialEq, Eq, Size, ToBytes)]
pub enum NlPayload<T, P> {
    /// Produced when deserializing an `Nlmsg::Error` message whose error code is `0`.
    Ack(Nlmsgerr<NlmsghdrAck<T>>),
    /// Produced when deserializing an `Nlmsg::Done` message that has the
    /// `NlmF::MULTI` flag set and still carries data.
    DumpExtAck(Nlmsgerr<()>),
    /// Produced when deserializing an `Nlmsg::Error` message with a non-zero
    /// error code; holds the code, the echoed header and any trailing attributes.
    Err(Nlmsgerr<NlmsghdrErr<T, P>>),
    /// The application payload: any message type other than `Nlmsg::Error`, or an
    /// `Nlmsg::Done` without `NlmF::MULTI` that still carries data.
    Payload(P),
    /// No payload. Produced when deserializing an `Nlmsg::Done` message with no
    /// payload data, and used as the placeholder left behind by `Nlmsghdr::get_err`.
    Empty,
}
48
49impl<T, P> FromBytesWithInput for NlPayload<T, P>
50where
51 P: Size + FromBytesWithInput<Input = usize>,
52 T: NlType,
53{
54 type Input = (usize, T, NlmF);
55
56 fn from_bytes_with_input(
57 buffer: &mut Cursor<impl AsRef<[u8]>>,
58 (input_size, input_type, flags): (usize, T, NlmF),
59 ) -> Result<Self, DeError> {
60 let pos = buffer.position();
61
62 let mut processing = || {
63 trace!("Deserializing data type {}", std::any::type_name::<Self>());
64 let ty_const: u16 = input_type.into();
65 if ty_const == Nlmsg::Done.into() {
66 if buffer.position() == buffer.get_ref().as_ref().len() as u64 {
67 Ok(NlPayload::Empty)
68 } else if flags.contains(NlmF::MULTI) {
69 trace!(
70 "Deserializing field type {}",
71 std::any::type_name::<Nlmsgerr<()>>(),
72 );
73 trace!("Input: {input_size:?}");
74 let ext = Nlmsgerr::from_bytes_with_input(buffer, input_size)?;
75 Ok(NlPayload::DumpExtAck(ext))
76 } else {
77 Ok(NlPayload::Payload(P::from_bytes_with_input(
80 buffer, input_size,
81 )?))
82 }
83 } else if ty_const == Nlmsg::Error.into() {
84 trace!(
85 "Deserializing field type {}",
86 std::any::type_name::<libc::c_int>()
87 );
88 let code = libc::c_int::from_bytes(buffer)?;
89 trace!("Field deserialized: {code:?}");
90 if code == 0 {
91 trace!(
92 "Deserializing field type {}",
93 std::any::type_name::<NlmsghdrErr<T, ()>>()
94 );
95 trace!("Input: {input_size:?}");
96 let nlmsg = NlmsghdrAck::<T>::from_bytes(buffer)?;
97 trace!("Field deserialized: {nlmsg:?}");
98 Ok(NlPayload::Ack(
99 NlmsgerrBuilder::default().nlmsg(nlmsg).build()?,
100 ))
101 } else {
102 trace!(
103 "Deserializing field type {}",
104 std::any::type_name::<NlmsghdrErr<T, ()>>()
105 );
106 let nlmsg = NlmsghdrErr::<T, P>::from_bytes(buffer)?;
107 trace!("Field deserialized: {nlmsg:?}");
108
109 trace!(
110 "Deserializing field type {}",
111 std::any::type_name::<GenlBuffer<u16, Buffer>>()
112 );
113 let input = input_size - size_of::<libc::c_int>() - nlmsg.padded_size();
114 trace!("Input: {input:?}");
115 let ext_ack = GenlBuffer::from_bytes_with_input(buffer, input)?;
116 trace!("Field deserialized: {ext_ack:?}");
117
118 Ok(NlPayload::Err(
119 NlmsgerrBuilder::default()
120 .error(code)
121 .nlmsg(nlmsg)
122 .ext_ack(ext_ack)
123 .build()?,
124 ))
125 }
126 } else {
127 Ok(NlPayload::Payload(P::from_bytes_with_input(
128 buffer, input_size,
129 )?))
130 }
131 };
132
133 match processing() {
134 Ok(o) => Ok(o),
135 Err(e) => {
136 buffer.set_position(pos);
137 Err(e)
138 }
139 }
140 }
141}
142
/// A netlink message: the fixed header fields followed by an [`NlPayload`].
///
/// Can be constructed with [`NlmsghdrBuilder`], whose `build` computes `nl_len`
/// from the finished message.
#[derive(Builder, Getters, Clone, Debug, PartialEq, Eq, Size, ToBytes, FromBytes, Header)]
#[neli(header_bound = "T: TypeSize")]
#[neli(from_bytes_bound = "T: NlType")]
#[neli(from_bytes_bound = "P: Size + FromBytesWithInput<Input = usize>")]
#[neli(padding)]
#[builder(build_fn(skip))]
#[builder(pattern = "owned")]
pub struct Nlmsghdr<T, P> {
    /// Total message length. It has no builder setter; `build` sets it to the
    /// padded size of the whole message, and `set_payload` keeps it in sync.
    #[builder(setter(skip))]
    #[getset(get = "pub")]
    nl_len: u32,
    /// Message type.
    #[getset(get = "pub")]
    nl_type: T,
    /// Message flags.
    #[getset(get = "pub")]
    nl_flags: NlmF,
    /// Sequence number; `build` defaults it to `0` when not set.
    #[getset(get = "pub")]
    nl_seq: u32,
    /// PID field; `build` defaults it to `0` when not set.
    #[getset(get = "pub")]
    nl_pid: u32,
    /// Message payload. It is deserialized with the byte count remaining after
    /// the header (`nl_len` minus the header size) plus the message type and
    /// flags, which together select the payload variant.
    // NOTE(review): `nl_len as usize - Self::header_size()` underflows if a
    // received header reports a length smaller than the header itself; consider
    // validating `nl_len` before this subtraction.
    #[neli(input = "(nl_len as usize - Self::header_size() as usize, nl_type, nl_flags)")]
    #[neli(size = "nl_len as usize - Self::header_size() as usize")]
    #[getset(get = "pub")]
    pub(crate) nl_payload: NlPayload<T, P>,
}
175
176impl<T, P> NlmsghdrBuilder<T, P>
177where
178 T: NlType,
179 P: Size,
180{
181 pub fn build(self) -> Result<Nlmsghdr<T, P>, NlmsghdrBuilderError> {
183 let nl_type = self
184 .nl_type
185 .ok_or_else(|| NlmsghdrBuilderError::from(UninitializedFieldError::new("nl_type")))?;
186 let nl_flags = self
187 .nl_flags
188 .ok_or_else(|| NlmsghdrBuilderError::from(UninitializedFieldError::new("nl_flags")))?;
189 let nl_seq = self.nl_seq.unwrap_or(0);
190 let nl_pid = self.nl_pid.unwrap_or(0);
191 let nl_payload = self.nl_payload.ok_or_else(|| {
192 NlmsghdrBuilderError::from(UninitializedFieldError::new("nl_payload"))
193 })?;
194
195 let mut nl = Nlmsghdr {
196 nl_len: 0,
197 nl_type,
198 nl_flags,
199 nl_seq,
200 nl_pid,
201 nl_payload,
202 };
203 nl.nl_len = nl.padded_size() as u32;
204 Ok(nl)
205 }
206}
207
208impl<T, P> Nlmsghdr<T, P>
209where
210 T: NlType,
211{
212 pub fn get_payload(&self) -> Option<&P> {
214 match self.nl_payload {
215 NlPayload::Payload(ref p) => Some(p),
216 _ => None,
217 }
218 }
219
220 pub fn get_err(&mut self) -> Option<Nlmsgerr<NlmsghdrErr<T, P>>> {
225 match self.nl_payload {
226 NlPayload::Err(_) => {
227 let mut payload = NlPayload::Empty;
228 swap(&mut self.nl_payload, &mut payload);
229 match payload {
230 NlPayload::Err(e) => Some(e),
231 _ => unreachable!(),
232 }
233 }
234 _ => None,
235 }
236 }
237}
238
239impl NlPayload<u16, Buffer> {
240 pub fn to_typed<T, P>(self, payload_size: usize) -> Result<NlPayload<T, P>, RouterError<T, P>>
242 where
243 T: NlType,
244 P: Size + FromBytesWithInput<Input = usize>,
245 {
246 match self {
247 NlPayload::Ack(a) => Ok(NlPayload::Ack(a.to_typed()?)),
248 NlPayload::Err(e) => Ok(NlPayload::Err(e.to_typed()?)),
249 NlPayload::DumpExtAck(a) => Ok(NlPayload::DumpExtAck(a)),
250 NlPayload::Payload(p) => Ok(NlPayload::Payload(P::from_bytes_with_input(
251 &mut Cursor::new(p),
252 payload_size,
253 )?)),
254 NlPayload::Empty => Ok(NlPayload::Empty),
255 }
256 }
257}
258
259impl<T, P> Nlmsghdr<T, P>
260where
261 T: NlType,
262 P: Size,
263{
264 pub fn set_payload(&mut self, p: NlPayload<T, P>) {
266 self.nl_len -= self.nl_payload.padded_size() as u32;
267 self.nl_len += p.padded_size() as u32;
268 self.nl_payload = p;
269 }
270}
271
272impl Nlmsghdr<u16, Buffer> {
273 pub fn to_typed<T, P>(self) -> Result<Nlmsghdr<T, P>, RouterError<T, P>>
275 where
276 T: NlType,
277 P: Size + FromBytesWithInput<Input = usize>,
278 {
279 Ok(NlmsghdrBuilder::default()
280 .nl_type(T::from(self.nl_type))
281 .nl_flags(self.nl_flags)
282 .nl_seq(self.nl_seq)
283 .nl_pid(self.nl_pid)
284 .nl_payload(
285 self.nl_payload
286 .to_typed::<T, P>(self.nl_len as usize - Self::header_size())?,
287 )
288 .build()?)
289 }
290}