1use std::{
2 collections::{HashMap, HashSet},
3 iter::once,
4 marker::PhantomData,
5 sync::{
6 Arc,
7 mpsc::{Receiver, Sender, TryRecvError, channel},
8 },
9 thread::spawn,
10};
11
12use log::{error, trace, warn};
13use parking_lot::Mutex;
14
15use crate::{
16 FromBytesWithInput, Size, ToBytes,
17 consts::{
18 genl::{CtrlAttr, CtrlAttrMcastGrp, CtrlCmd, Index},
19 nl::{GenlId, NlType, NlmF, Nlmsg},
20 socket::NlFamily,
21 },
22 err::RouterError,
23 genl::{AttrTypeBuilder, Genlmsghdr, GenlmsghdrBuilder, NlattrBuilder, NoUserHeader},
24 nl::{NlPayload, Nlmsghdr, NlmsghdrBuilder},
25 socket::synchronous::NlSocketHandle,
26 types::{Buffer, GenlBuffer, NlBuffer},
27 utils::{Groups, NetlinkBitArray},
28};
29
/// Replies to a controller `CtrlCmd::Getfamily` request as collected by
/// `NlRouter::get_genl_family`, or the error that ended the exchange.
type GenlFamily = Result<
    NlBuffer<GenlId, Genlmsghdr<CtrlCmd, CtrlAttr>>,
    RouterError<GenlId, Genlmsghdr<CtrlCmd, CtrlAttr>>,
>;
/// Map from request sequence number to the channel that receives the untyped replies
/// and errors for that request. Shared between the router, its processing thread and
/// every receiver handle.
type Senders =
    Arc<Mutex<HashMap<u32, Sender<Result<Nlmsghdr<u16, Buffer>, RouterError<u16, Buffer>>>>>>;
/// Return type of `NlRouter::connect`: the router plus the receiver handle for
/// messages that arrive on multicast groups.
type ConnectReturn<T> = Result<
    (
        T,
        NlRouterReceiverHandle<u16, Genlmsghdr<u8, u16, NoUserHeader>>,
    ),
    RouterError<u16, Buffer>,
>;
/// Returned by `spawn_processing_thread`: the sender used to ask the thread to exit,
/// and the receiving end of the multicast channel.
type ProcThreadReturn = (
    Sender<()>,
    Receiver<Result<Nlmsghdr<u16, Buffer>, RouterError<u16, Buffer>>>,
);
47
/// A netlink socket whose incoming messages are dispatched by a background thread
/// (started in [`NlRouter::connect`]) to per-request receiver handles, keyed by
/// sequence number.
pub struct NlRouter {
    /// Socket shared with the processing thread.
    socket: Arc<NlSocketHandle>,
    /// Counter behind `next_seq`; holds the sequence number the next request will use.
    seq: Mutex<u32>,
    /// Sequence number -> reply channel; shared with the processing thread and with
    /// every receiver handle.
    senders: Senders,
    /// Used by `Drop` to tell the processing thread to stop.
    exit_sender: Sender<()>,
}
56
57fn spawn_processing_thread(socket: Arc<NlSocketHandle>, senders: Senders) -> ProcThreadReturn {
58 let (exit_sender, exit_receiver) = channel();
59 let (multicast_sender, multicast_receiver) = channel();
60 spawn(move || {
61 while let Err(TryRecvError::Empty) = exit_receiver.try_recv() {
62 match socket.recv::<u16, Buffer>() {
63 Ok((iter, group)) => {
64 for msg in iter {
65 trace!("Message received: {msg:?}");
66 let mut seqs_to_remove = HashSet::new();
67 match msg {
68 Ok(m) => {
69 let seq = *m.nl_seq();
70 let lock = senders.lock();
71 if !group.is_empty() {
72 if multicast_sender.send(Ok(m)).is_err() {
73 warn!("{}", RouterError::<u16, Buffer>::ClosedChannel);
74 }
75 } else if let Some(sender) = lock.get(m.nl_seq()) {
76 if &socket.pid() == m.nl_pid() {
77 if sender.send(Ok(m)).is_err() {
78 error!("{}", RouterError::<u16, Buffer>::ClosedChannel);
79 seqs_to_remove.insert(seq);
80 }
81 } else {
82 for (seq, sender) in lock.iter() {
83 if sender
84 .send(Err(RouterError::BadSeqOrPid(m.clone())))
85 .is_err()
86 {
87 error!(
88 "{}",
89 RouterError::<u16, Buffer>::ClosedChannel
90 );
91 seqs_to_remove.insert(*seq);
92 }
93 }
94 }
95 } else {
96 for (seq, sender) in lock.iter() {
97 if sender
98 .send(Err(RouterError::BadSeqOrPid(m.clone())))
99 .is_err()
100 {
101 error!("{}", RouterError::<u16, Buffer>::ClosedChannel);
102 seqs_to_remove.insert(*seq);
103 }
104 }
105 }
106 }
107 Err(e) => {
108 let lock = senders.lock();
109 for (seq, sender) in lock.iter() {
110 if sender.send(Err(RouterError::from(e.clone()))).is_err() {
111 error!("{}", RouterError::<u16, Buffer>::ClosedChannel);
112 seqs_to_remove.insert(*seq);
113 }
114 }
115 }
116 }
117 for seq in seqs_to_remove {
118 senders.lock().remove(&seq);
119 }
120 }
121 }
122 Err(e) => {
123 let mut seqs_to_remove = HashSet::new();
124 let mut lock = senders.lock();
125 for (seq, sender) in lock.iter() {
126 if sender.send(Err(RouterError::from(e.clone()))).is_err() {
127 seqs_to_remove.insert(*seq);
128 error!("{}", RouterError::<u16, Buffer>::ClosedChannel);
129 break;
130 }
131 }
132 for seq in seqs_to_remove {
133 lock.remove(&seq);
134 }
135 }
136 }
137 }
138 });
139 (exit_sender, multicast_receiver)
140}
141
142impl NlRouter {
143 pub fn connect(proto: NlFamily, pid: Option<u32>, groups: Groups) -> ConnectReturn<Self> {
145 let socket = Arc::new(NlSocketHandle::connect(proto, pid, groups)?);
146 let senders = Arc::new(Mutex::new(HashMap::default()));
147 let (exit_sender, multicast_receiver) =
148 spawn_processing_thread(Arc::clone(&socket), Arc::clone(&senders));
149 let multicast_receiver =
150 NlRouterReceiverHandle::new(multicast_receiver, Arc::clone(&senders), false, None);
151 Ok((
152 NlRouter {
153 socket,
154 senders,
155 seq: Mutex::new(0),
156 exit_sender,
157 },
158 multicast_receiver,
159 ))
160 }
161
162 pub fn add_mcast_membership(&self, groups: Groups) -> Result<(), RouterError<u16, Buffer>> {
164 self.socket
165 .add_mcast_membership(groups)
166 .map_err(RouterError::from)
167 }
168
169 pub fn drop_mcast_membership(&self, groups: Groups) -> Result<(), RouterError<u16, Buffer>> {
171 self.socket
172 .drop_mcast_membership(groups)
173 .map_err(RouterError::from)
174 }
175
176 pub fn list_mcast_membership(&self) -> Result<NetlinkBitArray, RouterError<u16, Buffer>> {
178 self.socket
179 .list_mcast_membership()
180 .map_err(RouterError::from)
181 }
182
183 pub fn enable_ext_ack(&self, enable: bool) -> Result<(), RouterError<u16, Buffer>> {
186 self.socket
187 .enable_ext_ack(enable)
188 .map_err(RouterError::from)
189 }
190
191 pub fn get_ext_ack_enabled(&self) -> Result<bool, RouterError<u16, Buffer>> {
193 self.socket.get_ext_ack_enabled().map_err(RouterError::from)
194 }
195
196 pub fn enable_strict_checking(&self, enable: bool) -> Result<(), RouterError<u16, Buffer>> {
201 self.socket
202 .enable_strict_checking(enable)
203 .map_err(RouterError::from)
204 }
205
206 pub fn get_strict_checking_enabled(&self) -> Result<bool, RouterError<u16, Buffer>> {
210 self.socket
211 .get_strict_checking_enabled()
212 .map_err(RouterError::from)
213 }
214
215 pub fn pid(&self) -> u32 {
217 self.socket.pid()
218 }
219
220 fn next_seq(&self) -> u32 {
221 let mut lock = self.seq.lock();
222 let next = *lock;
223 *lock = lock.wrapping_add(1);
224 next
225 }
226
227 pub fn send<ST, SP, RT, RP>(
229 &self,
230 nl_type: ST,
231 nl_flags: NlmF,
232 nl_payload: NlPayload<ST, SP>,
233 ) -> Result<NlRouterReceiverHandle<RT, RP>, RouterError<ST, SP>>
234 where
235 ST: NlType,
236 SP: Size + ToBytes,
237 {
238 let msg = NlmsghdrBuilder::default()
239 .nl_type(nl_type)
240 .nl_flags(
241 nl_flags | NlmF::REQUEST,
243 )
244 .nl_pid(self.socket.pid())
245 .nl_seq(self.next_seq())
246 .nl_payload(nl_payload)
247 .build()?;
248
249 let (sender, receiver) = channel();
250 let seq = *msg.nl_seq();
251 self.senders.lock().insert(seq, sender);
252 let flags = *msg.nl_flags();
253
254 self.socket.send(&msg)?;
255
256 Ok(NlRouterReceiverHandle::new(
257 receiver,
258 Arc::clone(&self.senders),
259 flags.contains(NlmF::ACK) && !flags.contains(NlmF::DUMP),
260 Some(seq),
261 ))
262 }
263
264 fn get_genl_family(&self, family_name: &str) -> GenlFamily {
265 let recv = self.send(
266 GenlId::Ctrl,
267 NlmF::ACK,
268 NlPayload::Payload(
269 GenlmsghdrBuilder::default()
270 .cmd(CtrlCmd::Getfamily)
271 .version(2)
272 .attrs(
273 once(
274 NlattrBuilder::default()
275 .nla_type(
276 AttrTypeBuilder::default()
277 .nla_type(CtrlAttr::FamilyName)
278 .build()?,
279 )
280 .nla_payload(family_name)
281 .build()?,
282 )
283 .collect::<GenlBuffer<_, _>>(),
284 )
285 .build()?,
286 ),
287 )?;
288
289 let mut buffer = NlBuffer::new();
290 for msg in recv {
291 buffer.push(msg?);
292 }
293 Ok(buffer)
294 }
295
296 pub fn resolve_genl_family(
299 &self,
300 family_name: &str,
301 ) -> Result<u16, RouterError<GenlId, Genlmsghdr<CtrlCmd, CtrlAttr>>> {
302 let mut res = Err(RouterError::new(format!(
303 "Generic netlink family {family_name} was not found"
304 )));
305
306 let nlhdrs = self.get_genl_family(family_name)?;
307 for nlhdr in nlhdrs.into_iter() {
308 if let NlPayload::Payload(p) = nlhdr.nl_payload() {
309 let handle = p.attrs().get_attr_handle();
310 if let Ok(u) = handle.get_attr_payload_as::<u16>(CtrlAttr::FamilyId) {
311 res = Ok(u);
312 }
313 }
314 }
315
316 res
317 }
318
319 pub fn resolve_nl_mcast_group(
322 &self,
323 family_name: &str,
324 mcast_name: &str,
325 ) -> Result<u32, RouterError<GenlId, Genlmsghdr<CtrlCmd, CtrlAttr>>> {
326 let mut res = Err(RouterError::new(format!(
327 "Failed to resolve multicast group ID for family name {family_name}, multicast group name {mcast_name}"
328 )));
329
330 let nlhdrs = self.get_genl_family(family_name)?;
331 for nlhdr in nlhdrs {
332 if let NlPayload::Payload(p) = nlhdr.nl_payload() {
333 let handle = p.attrs().get_attr_handle();
334 let mcast_groups = handle.get_nested_attributes::<Index>(CtrlAttr::McastGroups)?;
335 if let Some(id) = mcast_groups.iter().find_map(|item| {
336 let nested_attrs = item.get_attr_handle::<CtrlAttrMcastGrp>().ok()?;
337 let string = nested_attrs
338 .get_attr_payload_as_with_len::<String>(CtrlAttrMcastGrp::Name)
339 .ok()?;
340 if string.as_str() == mcast_name {
341 nested_attrs
342 .get_attr_payload_as::<u32>(CtrlAttrMcastGrp::Id)
343 .ok()
344 } else {
345 None
346 }
347 }) {
348 res = Ok(id);
349 }
350 }
351 }
352
353 res
354 }
355
356 pub fn lookup_id(
358 &self,
359 id: u32,
360 ) -> Result<(String, String), RouterError<GenlId, Genlmsghdr<CtrlCmd, CtrlAttr>>> {
361 let mut res = Err(RouterError::new(
362 "ID does not correspond to a multicast group",
363 ));
364
365 let recv = self.send(
366 GenlId::Ctrl,
367 NlmF::DUMP,
368 NlPayload::Payload(
369 GenlmsghdrBuilder::<CtrlCmd, CtrlAttr, NoUserHeader>::default()
370 .cmd(CtrlCmd::Getfamily)
371 .version(2)
372 .attrs(GenlBuffer::new())
373 .build()?,
374 ),
375 )?;
376 for res_msg in recv {
377 let msg = res_msg?;
378
379 if let NlPayload::Payload(p) = msg.nl_payload() {
380 let attributes = p.attrs().get_attr_handle();
381 let name =
382 attributes.get_attr_payload_as_with_len::<String>(CtrlAttr::FamilyName)?;
383 let groups = match attributes.get_nested_attributes::<Index>(CtrlAttr::McastGroups)
384 {
385 Ok(grps) => grps,
386 Err(_) => continue,
387 };
388 for group_by_index in groups.iter() {
389 let attributes = group_by_index.get_attr_handle::<CtrlAttrMcastGrp>()?;
390 if let Ok(mcid) = attributes.get_attr_payload_as::<u32>(CtrlAttrMcastGrp::Id) {
391 if mcid == id {
392 let mcast_name = attributes
393 .get_attr_payload_as_with_len::<String>(CtrlAttrMcastGrp::Name)?;
394 res = Ok((name.clone(), mcast_name));
395 }
396 }
397 }
398 }
399 }
400
401 res
402 }
403}
404
405impl Drop for NlRouter {
406 fn drop(&mut self) {
407 if self.exit_sender.send(()).is_err() {
408 warn!("Failed to send shutdown message; processing thread should exit anyway");
409 }
410 }
411}
412
/// Receives the messages routed to one request, deserialized as `Nlmsghdr<T, P>`.
/// The handle returned by [`NlRouter::connect`] receives multicast messages instead.
pub struct NlRouterReceiverHandle<T, P> {
    /// Untyped messages and errors forwarded by the processing thread.
    receiver: Receiver<Result<Nlmsghdr<u16, Buffer>, RouterError<u16, Buffer>>>,
    /// Shared sequence-number map; used on drop to deregister `seq`.
    senders: Senders,
    /// Whether an ACK is expected after the final reply. `NlRouter::send` sets this
    /// when `NlmF::ACK` was requested without `NlmF::DUMP`.
    needs_ack: bool,
    /// Sequence number of the request; `None` for the multicast handle.
    seq: Option<u32>,
    /// Once set, the stream has ended and every further call yields `None`.
    next_is_none: bool,
    /// Set after the final reply when `needs_ack` is true; the next message must be
    /// an ACK.
    next_is_ack: bool,
    /// Carries the `T`/`P` type parameters without storing values of those types.
    data: PhantomData<(T, P)>,
}
423
424impl<T, P> NlRouterReceiverHandle<T, P> {
425 fn new(
426 receiver: Receiver<Result<Nlmsghdr<u16, Buffer>, RouterError<u16, Buffer>>>,
427 senders: Senders,
428 needs_ack: bool,
429 seq: Option<u32>,
430 ) -> Self {
431 NlRouterReceiverHandle {
432 receiver,
433 senders,
434 needs_ack,
435 seq,
436 next_is_none: false,
437 next_is_ack: false,
438 data: PhantomData,
439 }
440 }
441}
442
443impl<T, P> NlRouterReceiverHandle<T, P>
444where
445 T: NlType,
446 P: Size + FromBytesWithInput<Input = usize>,
447{
448 pub fn next_typed<TT, PP>(&mut self) -> Option<Result<Nlmsghdr<TT, PP>, RouterError<TT, PP>>>
451 where
452 TT: NlType,
453 PP: Size + FromBytesWithInput<Input = usize>,
454 {
455 if self.next_is_none {
456 return None;
457 }
458
459 let mut msg = match self.receiver.recv() {
460 Ok(untyped) => match untyped {
461 Ok(u) => match u.to_typed::<TT, PP>() {
462 Ok(t) => t,
463 Err(e) => {
464 self.next_is_none = true;
465 return Some(Err(e));
466 }
467 },
468 Err(e) => {
469 self.next_is_none = true;
470 return Some(Err(match e.to_typed() {
471 Ok(e) => e,
472 Err(e) => e,
473 }));
474 }
475 },
476 Err(_) => {
477 self.next_is_none = true;
478 return Some(Err(RouterError::ClosedChannel));
479 }
480 };
481
482 let nl_type = Nlmsg::from((*msg.nl_type()).into());
483 if let NlPayload::Ack(_) = msg.nl_payload() {
484 self.next_is_none = true;
485 if !self.needs_ack {
486 return Some(Err(RouterError::UnexpectedAck));
487 }
488 } else if let Some(e) = msg.get_err() {
489 self.next_is_none = true;
490 if self.next_is_ack {
491 return Some(Err(RouterError::NoAck));
492 } else {
493 return Some(Err(RouterError::<TT, PP>::Nlmsgerr(e)));
494 }
495 } else if (!msg.nl_flags().contains(NlmF::MULTI) || nl_type == Nlmsg::Done)
496 && self.seq.is_some()
497 {
498 assert!(!self.next_is_ack);
499
500 if self.needs_ack {
501 self.next_is_ack = true;
502 } else {
503 self.next_is_none = true;
504 }
505 } else if self.next_is_ack {
506 self.next_is_none = true;
507 return Some(Err(RouterError::NoAck));
508 }
509
510 trace!("Router received message: {msg:?}");
511
512 Some(Ok(msg))
513 }
514}
515
516impl<T, P> Iterator for NlRouterReceiverHandle<T, P>
517where
518 T: NlType,
519 P: Size + FromBytesWithInput<Input = usize>,
520{
521 type Item = Result<Nlmsghdr<T, P>, RouterError<T, P>>;
522
523 fn next(&mut self) -> Option<Self::Item> {
524 self.next_typed::<T, P>()
525 }
526}
527
528impl<T, P> Drop for NlRouterReceiverHandle<T, P> {
529 fn drop(&mut self) {
530 if let Some(seq) = self.seq {
531 self.senders.lock().remove(&seq);
532 }
533 }
534}
535
#[cfg(test)]
mod test {
    use super::*;

    use crate::test::setup;

    /// Resolves `nlctrl`/`notify` and `devlink`/`config`, joins whichever groups exist,
    /// checks that membership is reported, then leaves them and checks it is gone.
    #[test]
    fn real_test_mcast_groups() {
        setup();

        let (router, _mcast_handle) =
            NlRouter::connect(NlFamily::Generic, None, Groups::empty()).unwrap();
        router.enable_strict_checking(true).unwrap();

        let notify = router.resolve_nl_mcast_group("nlctrl", "notify");
        let config = router.resolve_nl_mcast_group("devlink", "config");

        // A family that is not available surfaces as `Nlmsgerr`; any other error fails
        // the test.
        let ids: Vec<u32> = match (notify, config) {
            (Ok(n), Ok(c)) => vec![n, c],
            (Ok(n), Err(RouterError::Nlmsgerr(_))) => vec![n],
            (Err(RouterError::Nlmsgerr(_)), Ok(c)) => vec![c],
            (Err(RouterError::Nlmsgerr(_)), Err(RouterError::Nlmsgerr(_))) => return,
            (Err(e), _) => panic!("Unexpected result from resolve_nl_mcast_group: {e:?}"),
            (_, Err(e)) => panic!("Unexpected result from resolve_nl_mcast_group: {e:?}"),
        };
        router
            .add_mcast_membership(Groups::new_groups(ids.as_slice()))
            .unwrap();

        let joined = router.list_mcast_membership().unwrap();
        for &id in &ids {
            assert!(joined.is_set(id as usize));
        }

        router
            .drop_mcast_membership(Groups::new_groups(ids.as_slice()))
            .unwrap();
        let remaining = router.list_mcast_membership().unwrap();

        for &id in &ids {
            assert!(!remaining.is_set(id as usize));
        }
    }
}