// neli/router/synchronous.rs

1use std::{
2    collections::{HashMap, HashSet},
3    iter::once,
4    marker::PhantomData,
5    sync::{
6        Arc,
7        mpsc::{Receiver, Sender, TryRecvError, channel},
8    },
9    thread::spawn,
10};
11
12use log::{error, trace, warn};
13use parking_lot::Mutex;
14
15use crate::{
16    FromBytesWithInput, Size, ToBytes,
17    consts::{
18        genl::{CtrlAttr, CtrlAttrMcastGrp, CtrlCmd, Index},
19        nl::{GenlId, NlType, NlmF, Nlmsg},
20        socket::NlFamily,
21    },
22    err::RouterError,
23    genl::{AttrTypeBuilder, Genlmsghdr, GenlmsghdrBuilder, NlattrBuilder, NoUserHeader},
24    nl::{NlPayload, Nlmsghdr, NlmsghdrBuilder},
25    socket::synchronous::NlSocketHandle,
26    types::{Buffer, GenlBuffer, NlBuffer},
27    utils::{Groups, NetlinkBitArray},
28};
29
30type GenlFamily = Result<
31    NlBuffer<GenlId, Genlmsghdr<CtrlCmd, CtrlAttr>>,
32    RouterError<GenlId, Genlmsghdr<CtrlCmd, CtrlAttr>>,
33>;
34type Senders =
35    Arc<Mutex<HashMap<u32, Sender<Result<Nlmsghdr<u16, Buffer>, RouterError<u16, Buffer>>>>>>;
36type ConnectReturn<T> = Result<
37    (
38        T,
39        NlRouterReceiverHandle<u16, Genlmsghdr<u8, u16, NoUserHeader>>,
40    ),
41    RouterError<u16, Buffer>,
42>;
43type ProcThreadReturn = (
44    Sender<()>,
45    Receiver<Result<Nlmsghdr<u16, Buffer>, RouterError<u16, Buffer>>>,
46);
47
48/// A high-level handle for sending messages and generating a handle that validates
49/// all of the received messages.
50pub struct NlRouter {
51    socket: Arc<NlSocketHandle>,
52    seq: Mutex<u32>,
53    senders: Senders,
54    exit_sender: Sender<()>,
55}
56
57fn spawn_processing_thread(socket: Arc<NlSocketHandle>, senders: Senders) -> ProcThreadReturn {
58    let (exit_sender, exit_receiver) = channel();
59    let (multicast_sender, multicast_receiver) = channel();
60    spawn(move || {
61        while let Err(TryRecvError::Empty) = exit_receiver.try_recv() {
62            match socket.recv::<u16, Buffer>() {
63                Ok((iter, group)) => {
64                    for msg in iter {
65                        trace!("Message received: {msg:?}");
66                        let mut seqs_to_remove = HashSet::new();
67                        match msg {
68                            Ok(m) => {
69                                let seq = *m.nl_seq();
70                                let lock = senders.lock();
71                                if !group.is_empty() {
72                                    if multicast_sender.send(Ok(m)).is_err() {
73                                        warn!("{}", RouterError::<u16, Buffer>::ClosedChannel);
74                                    }
75                                } else if let Some(sender) = lock.get(m.nl_seq()) {
76                                    if &socket.pid() == m.nl_pid() {
77                                        if sender.send(Ok(m)).is_err() {
78                                            error!("{}", RouterError::<u16, Buffer>::ClosedChannel);
79                                            seqs_to_remove.insert(seq);
80                                        }
81                                    } else {
82                                        for (seq, sender) in lock.iter() {
83                                            if sender
84                                                .send(Err(RouterError::BadSeqOrPid(m.clone())))
85                                                .is_err()
86                                            {
87                                                error!(
88                                                    "{}",
89                                                    RouterError::<u16, Buffer>::ClosedChannel
90                                                );
91                                                seqs_to_remove.insert(*seq);
92                                            }
93                                        }
94                                    }
95                                } else {
96                                    for (seq, sender) in lock.iter() {
97                                        if sender
98                                            .send(Err(RouterError::BadSeqOrPid(m.clone())))
99                                            .is_err()
100                                        {
101                                            error!("{}", RouterError::<u16, Buffer>::ClosedChannel);
102                                            seqs_to_remove.insert(*seq);
103                                        }
104                                    }
105                                }
106                            }
107                            Err(e) => {
108                                let lock = senders.lock();
109                                for (seq, sender) in lock.iter() {
110                                    if sender.send(Err(RouterError::from(e.clone()))).is_err() {
111                                        error!("{}", RouterError::<u16, Buffer>::ClosedChannel);
112                                        seqs_to_remove.insert(*seq);
113                                    }
114                                }
115                            }
116                        }
117                        for seq in seqs_to_remove {
118                            senders.lock().remove(&seq);
119                        }
120                    }
121                }
122                Err(e) => {
123                    let mut seqs_to_remove = HashSet::new();
124                    let mut lock = senders.lock();
125                    for (seq, sender) in lock.iter() {
126                        if sender.send(Err(RouterError::from(e.clone()))).is_err() {
127                            seqs_to_remove.insert(*seq);
128                            error!("{}", RouterError::<u16, Buffer>::ClosedChannel);
129                            break;
130                        }
131                    }
132                    for seq in seqs_to_remove {
133                        lock.remove(&seq);
134                    }
135                }
136            }
137        }
138    });
139    (exit_sender, multicast_receiver)
140}
141
142impl NlRouter {
143    /// Equivalent of `socket` and `bind` calls.
144    pub fn connect(proto: NlFamily, pid: Option<u32>, groups: Groups) -> ConnectReturn<Self> {
145        let socket = Arc::new(NlSocketHandle::connect(proto, pid, groups)?);
146        let senders = Arc::new(Mutex::new(HashMap::default()));
147        let (exit_sender, multicast_receiver) =
148            spawn_processing_thread(Arc::clone(&socket), Arc::clone(&senders));
149        let multicast_receiver =
150            NlRouterReceiverHandle::new(multicast_receiver, Arc::clone(&senders), false, None);
151        Ok((
152            NlRouter {
153                socket,
154                senders,
155                seq: Mutex::new(0),
156                exit_sender,
157            },
158            multicast_receiver,
159        ))
160    }
161
162    /// Join multicast groups for a socket.
163    pub fn add_mcast_membership(&self, groups: Groups) -> Result<(), RouterError<u16, Buffer>> {
164        self.socket
165            .add_mcast_membership(groups)
166            .map_err(RouterError::from)
167    }
168
169    /// Leave multicast groups for a socket.
170    pub fn drop_mcast_membership(&self, groups: Groups) -> Result<(), RouterError<u16, Buffer>> {
171        self.socket
172            .drop_mcast_membership(groups)
173            .map_err(RouterError::from)
174    }
175
176    /// List joined groups for a socket.
177    pub fn list_mcast_membership(&self) -> Result<NetlinkBitArray, RouterError<u16, Buffer>> {
178        self.socket
179            .list_mcast_membership()
180            .map_err(RouterError::from)
181    }
182
183    /// If [`true`] is passed in, enable extended ACKs for this socket. If [`false`]
184    /// is passed in, disable extended ACKs for this socket.
185    pub fn enable_ext_ack(&self, enable: bool) -> Result<(), RouterError<u16, Buffer>> {
186        self.socket
187            .enable_ext_ack(enable)
188            .map_err(RouterError::from)
189    }
190
191    /// Return [`true`] if an extended ACK is enabled for this socket.
192    pub fn get_ext_ack_enabled(&self) -> Result<bool, RouterError<u16, Buffer>> {
193        self.socket.get_ext_ack_enabled().map_err(RouterError::from)
194    }
195
196    /// If [`true`] is passed in, enable strict checking for this socket. If [`false`]
197    /// is passed in, disable strict checking for for this socket.
198    /// Only supported by `NlFamily::Route` sockets.
199    /// Requires Linux >= 4.20.
200    pub fn enable_strict_checking(&self, enable: bool) -> Result<(), RouterError<u16, Buffer>> {
201        self.socket
202            .enable_strict_checking(enable)
203            .map_err(RouterError::from)
204    }
205
206    /// Return [`true`] if strict checking is enabled for this socket.
207    /// Only supported by `NlFamily::Route` sockets.
208    /// Requires Linux >= 4.20.
209    pub fn get_strict_checking_enabled(&self) -> Result<bool, RouterError<u16, Buffer>> {
210        self.socket
211            .get_strict_checking_enabled()
212            .map_err(RouterError::from)
213    }
214
215    /// Get the PID for the current socket.
216    pub fn pid(&self) -> u32 {
217        self.socket.pid()
218    }
219
220    fn next_seq(&self) -> u32 {
221        let mut lock = self.seq.lock();
222        let next = *lock;
223        *lock = lock.wrapping_add(1);
224        next
225    }
226
227    /// Send a message and return a handle for receiving responses from this message.
228    pub fn send<ST, SP, RT, RP>(
229        &self,
230        nl_type: ST,
231        nl_flags: NlmF,
232        nl_payload: NlPayload<ST, SP>,
233    ) -> Result<NlRouterReceiverHandle<RT, RP>, RouterError<ST, SP>>
234    where
235        ST: NlType,
236        SP: Size + ToBytes,
237    {
238        let msg = NlmsghdrBuilder::default()
239            .nl_type(nl_type)
240            .nl_flags(
241                // Required for messages
242                nl_flags | NlmF::REQUEST,
243            )
244            .nl_pid(self.socket.pid())
245            .nl_seq(self.next_seq())
246            .nl_payload(nl_payload)
247            .build()?;
248
249        let (sender, receiver) = channel();
250        let seq = *msg.nl_seq();
251        self.senders.lock().insert(seq, sender);
252        let flags = *msg.nl_flags();
253
254        self.socket.send(&msg)?;
255
256        Ok(NlRouterReceiverHandle::new(
257            receiver,
258            Arc::clone(&self.senders),
259            flags.contains(NlmF::ACK) && !flags.contains(NlmF::DUMP),
260            Some(seq),
261        ))
262    }
263
264    fn get_genl_family(&self, family_name: &str) -> GenlFamily {
265        let recv = self.send(
266            GenlId::Ctrl,
267            NlmF::ACK,
268            NlPayload::Payload(
269                GenlmsghdrBuilder::default()
270                    .cmd(CtrlCmd::Getfamily)
271                    .version(2)
272                    .attrs(
273                        once(
274                            NlattrBuilder::default()
275                                .nla_type(
276                                    AttrTypeBuilder::default()
277                                        .nla_type(CtrlAttr::FamilyName)
278                                        .build()?,
279                                )
280                                .nla_payload(family_name)
281                                .build()?,
282                        )
283                        .collect::<GenlBuffer<_, _>>(),
284                    )
285                    .build()?,
286            ),
287        )?;
288
289        let mut buffer = NlBuffer::new();
290        for msg in recv {
291            buffer.push(msg?);
292        }
293        Ok(buffer)
294    }
295
296    /// Convenience function for resolving a [`str`] containing the
297    /// generic netlink family name to a numeric generic netlink ID.
298    pub fn resolve_genl_family(
299        &self,
300        family_name: &str,
301    ) -> Result<u16, RouterError<GenlId, Genlmsghdr<CtrlCmd, CtrlAttr>>> {
302        let mut res = Err(RouterError::new(format!(
303            "Generic netlink family {family_name} was not found"
304        )));
305
306        let nlhdrs = self.get_genl_family(family_name)?;
307        for nlhdr in nlhdrs.into_iter() {
308            if let NlPayload::Payload(p) = nlhdr.nl_payload() {
309                let handle = p.attrs().get_attr_handle();
310                if let Ok(u) = handle.get_attr_payload_as::<u16>(CtrlAttr::FamilyId) {
311                    res = Ok(u);
312                }
313            }
314        }
315
316        res
317    }
318
319    /// Convenience function for resolving a [`str`] containing the
320    /// multicast group name to a numeric multicast group ID.
321    pub fn resolve_nl_mcast_group(
322        &self,
323        family_name: &str,
324        mcast_name: &str,
325    ) -> Result<u32, RouterError<GenlId, Genlmsghdr<CtrlCmd, CtrlAttr>>> {
326        let mut res = Err(RouterError::new(format!(
327            "Failed to resolve multicast group ID for family name {family_name}, multicast group name {mcast_name}"
328        )));
329
330        let nlhdrs = self.get_genl_family(family_name)?;
331        for nlhdr in nlhdrs {
332            if let NlPayload::Payload(p) = nlhdr.nl_payload() {
333                let handle = p.attrs().get_attr_handle();
334                let mcast_groups = handle.get_nested_attributes::<Index>(CtrlAttr::McastGroups)?;
335                if let Some(id) = mcast_groups.iter().find_map(|item| {
336                    let nested_attrs = item.get_attr_handle::<CtrlAttrMcastGrp>().ok()?;
337                    let string = nested_attrs
338                        .get_attr_payload_as_with_len::<String>(CtrlAttrMcastGrp::Name)
339                        .ok()?;
340                    if string.as_str() == mcast_name {
341                        nested_attrs
342                            .get_attr_payload_as::<u32>(CtrlAttrMcastGrp::Id)
343                            .ok()
344                    } else {
345                        None
346                    }
347                }) {
348                    res = Ok(id);
349                }
350            }
351        }
352
353        res
354    }
355
356    /// Look up netlink family and multicast group name by ID.
357    pub fn lookup_id(
358        &self,
359        id: u32,
360    ) -> Result<(String, String), RouterError<GenlId, Genlmsghdr<CtrlCmd, CtrlAttr>>> {
361        let mut res = Err(RouterError::new(
362            "ID does not correspond to a multicast group",
363        ));
364
365        let recv = self.send(
366            GenlId::Ctrl,
367            NlmF::DUMP,
368            NlPayload::Payload(
369                GenlmsghdrBuilder::<CtrlCmd, CtrlAttr, NoUserHeader>::default()
370                    .cmd(CtrlCmd::Getfamily)
371                    .version(2)
372                    .attrs(GenlBuffer::new())
373                    .build()?,
374            ),
375        )?;
376        for res_msg in recv {
377            let msg = res_msg?;
378
379            if let NlPayload::Payload(p) = msg.nl_payload() {
380                let attributes = p.attrs().get_attr_handle();
381                let name =
382                    attributes.get_attr_payload_as_with_len::<String>(CtrlAttr::FamilyName)?;
383                let groups = match attributes.get_nested_attributes::<Index>(CtrlAttr::McastGroups)
384                {
385                    Ok(grps) => grps,
386                    Err(_) => continue,
387                };
388                for group_by_index in groups.iter() {
389                    let attributes = group_by_index.get_attr_handle::<CtrlAttrMcastGrp>()?;
390                    if let Ok(mcid) = attributes.get_attr_payload_as::<u32>(CtrlAttrMcastGrp::Id) {
391                        if mcid == id {
392                            let mcast_name = attributes
393                                .get_attr_payload_as_with_len::<String>(CtrlAttrMcastGrp::Name)?;
394                            res = Ok((name.clone(), mcast_name));
395                        }
396                    }
397                }
398            }
399        }
400
401        res
402    }
403}
404
405impl Drop for NlRouter {
406    fn drop(&mut self) {
407        if self.exit_sender.send(()).is_err() {
408            warn!("Failed to send shutdown message; processing thread should exit anyway");
409        }
410    }
411}
412
413/// A handle for receiving and validating all messages that correspond to a request.
414pub struct NlRouterReceiverHandle<T, P> {
415    receiver: Receiver<Result<Nlmsghdr<u16, Buffer>, RouterError<u16, Buffer>>>,
416    senders: Senders,
417    needs_ack: bool,
418    seq: Option<u32>,
419    next_is_none: bool,
420    next_is_ack: bool,
421    data: PhantomData<(T, P)>,
422}
423
424impl<T, P> NlRouterReceiverHandle<T, P> {
425    fn new(
426        receiver: Receiver<Result<Nlmsghdr<u16, Buffer>, RouterError<u16, Buffer>>>,
427        senders: Senders,
428        needs_ack: bool,
429        seq: Option<u32>,
430    ) -> Self {
431        NlRouterReceiverHandle {
432            receiver,
433            senders,
434            needs_ack,
435            seq,
436            next_is_none: false,
437            next_is_ack: false,
438            data: PhantomData,
439        }
440    }
441}
442
443impl<T, P> NlRouterReceiverHandle<T, P>
444where
445    T: NlType,
446    P: Size + FromBytesWithInput<Input = usize>,
447{
448    /// Imitates the [`Iterator`] API but allows parsing differently typed
449    /// messages in a sequence of messages meant for this receiver.
450    pub fn next_typed<TT, PP>(&mut self) -> Option<Result<Nlmsghdr<TT, PP>, RouterError<TT, PP>>>
451    where
452        TT: NlType,
453        PP: Size + FromBytesWithInput<Input = usize>,
454    {
455        if self.next_is_none {
456            return None;
457        }
458
459        let mut msg = match self.receiver.recv() {
460            Ok(untyped) => match untyped {
461                Ok(u) => match u.to_typed::<TT, PP>() {
462                    Ok(t) => t,
463                    Err(e) => {
464                        self.next_is_none = true;
465                        return Some(Err(e));
466                    }
467                },
468                Err(e) => {
469                    self.next_is_none = true;
470                    return Some(Err(match e.to_typed() {
471                        Ok(e) => e,
472                        Err(e) => e,
473                    }));
474                }
475            },
476            Err(_) => {
477                self.next_is_none = true;
478                return Some(Err(RouterError::ClosedChannel));
479            }
480        };
481
482        let nl_type = Nlmsg::from((*msg.nl_type()).into());
483        if let NlPayload::Ack(_) = msg.nl_payload() {
484            self.next_is_none = true;
485            if !self.needs_ack {
486                return Some(Err(RouterError::UnexpectedAck));
487            }
488        } else if let Some(e) = msg.get_err() {
489            self.next_is_none = true;
490            if self.next_is_ack {
491                return Some(Err(RouterError::NoAck));
492            } else {
493                return Some(Err(RouterError::<TT, PP>::Nlmsgerr(e)));
494            }
495        } else if (!msg.nl_flags().contains(NlmF::MULTI) || nl_type == Nlmsg::Done)
496            && self.seq.is_some()
497        {
498            assert!(!self.next_is_ack);
499
500            if self.needs_ack {
501                self.next_is_ack = true;
502            } else {
503                self.next_is_none = true;
504            }
505        } else if self.next_is_ack {
506            self.next_is_none = true;
507            return Some(Err(RouterError::NoAck));
508        }
509
510        trace!("Router received message: {msg:?}");
511
512        Some(Ok(msg))
513    }
514}
515
516impl<T, P> Iterator for NlRouterReceiverHandle<T, P>
517where
518    T: NlType,
519    P: Size + FromBytesWithInput<Input = usize>,
520{
521    type Item = Result<Nlmsghdr<T, P>, RouterError<T, P>>;
522
523    fn next(&mut self) -> Option<Self::Item> {
524        self.next_typed::<T, P>()
525    }
526}
527
528impl<T, P> Drop for NlRouterReceiverHandle<T, P> {
529    fn drop(&mut self) {
530        if let Some(seq) = self.seq {
531            self.senders.lock().remove(&seq);
532        }
533    }
534}
535
#[cfg(test)]
mod test {
    use super::*;

    use crate::test::setup;

    #[test]
    fn real_test_mcast_groups() {
        setup();

        let (sock, _multicast) =
            NlRouter::connect(NlFamily::Generic, None, Groups::empty()).unwrap();
        sock.enable_strict_checking(true).unwrap();

        // `Nlmsgerr` results are skipped (presumably the family is unavailable on
        // this kernel). Any other error fails the test, and each lookup reports its
        // own error rather than the other lookup's.
        let mut ids = Vec::new();
        for (family, group) in [("nlctrl", "notify"), ("devlink", "config")] {
            match sock.resolve_nl_mcast_group(family, group) {
                Ok(id) => ids.push(id),
                Err(RouterError::Nlmsgerr(_)) => (),
                Err(e) => panic!("Unexpected result from resolve_nl_mcast_group: {e:?}"),
            }
        }
        if ids.is_empty() {
            return;
        }

        sock.add_mcast_membership(Groups::new_groups(ids.as_slice()))
            .unwrap();
        let groups = sock.list_mcast_membership().unwrap();
        for id in &ids {
            assert!(groups.is_set(*id as usize));
        }

        sock.drop_mcast_membership(Groups::new_groups(ids.as_slice()))
            .unwrap();
        let groups = sock.list_mcast_membership().unwrap();
        for id in &ids {
            assert!(!groups.is_set(*id as usize));
        }
    }
}
588}