neli/
iter.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
//! Module for iteration over netlink responses

use std::{fmt::Debug, marker::PhantomData};

use crate::{
    consts::nl::{NlType, NlmF, Nlmsg},
    err::NlError,
    nl::{NlPayload, Nlmsghdr},
    socket::NlSocketHandle,
    FromBytesWithInput,
};

/// Define iteration behavior when traversing a stream of netlink
/// messages.
///
/// The value is a plain, field-less selector, so it is cheap to copy,
/// compare and print.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IterationBehavior {
    /// End iteration of multi-part messages when a DONE message is
    /// reached.
    EndMultiOnDone,
    /// Iterate indefinitely. Mostly useful for multicast
    /// subscriptions.
    IterIndefinitely,
}

/// Iterator over messages in an
/// [`NlSocketHandle`][crate::socket::NlSocketHandle] type.
///
/// This iterator has two high-level options:
/// * Iterate indefinitely over messages. This is most
///   useful in the case of subscribing to messages in a
///   multicast group.
/// * Iterate until a message of type
///   [`Nlmsg::Done`][crate::consts::nl::Nlmsg::Done] is returned.
///   This is most useful in the case of request-response workflows
///   where the iterator will parse and iterate through all of the
///   messages with [`NlmF::Multi`][crate::consts::nl::NlmF::Multi] set
///   until a message with
///   [`Nlmsg::Done`][crate::consts::nl::Nlmsg::Done] is
///   received, at which point [`None`] will be returned, indicating
///   the end of the response.
pub struct NlMessageIter<'a, T, P> {
    /// Socket handle that messages are received from.
    sock_ref: &'a mut NlSocketHandle,
    /// Termination state of the iterator:
    /// * `None` - iterate indefinitely; the flag is never updated.
    /// * `Some(false)` - a terminating message has not been seen yet.
    /// * `Some(true)` - a terminating message has already been yielded,
    ///   so the next call yields `None` without touching the socket.
    next_is_none: Option<bool>,
    /// Zero-sized marker tying the iterator to the message type `T`.
    type_: PhantomData<T>,
    /// Zero-sized marker tying the iterator to the payload type `P`.
    payload: PhantomData<P>,
}

impl<'a, T, P> NlMessageIter<'a, T, P>
where
    T: NlType + Debug,
    P: FromBytesWithInput<'a, Input = usize> + Debug,
{
    /// Create an iterator yielding [`Nlmsghdr`][crate::nl::Nlmsghdr]
    /// values received through `sock_ref`.
    ///
    /// `behavior` selects when the iterator stops:
    ///
    /// * [`IterationBehavior::IterIndefinitely`] treats the socket as a
    ///   never-ending stream of messages.
    /// * [`IterationBehavior::EndMultiOnDone`] makes the iterator honor
    ///   [`NlmF::Multi`][crate::consts::nl::NlmF::Multi] and
    ///   [`Nlmsg::Done`][crate::consts::nl::Nlmsg::Done]: it yields
    ///   exactly one message when
    ///   [`NlmF::Multi`][crate::consts::nl::NlmF::Multi] is not set, or
    ///   every consecutive message carrying
    ///   [`NlmF::Multi`][crate::consts::nl::NlmF::Multi] up to the
    ///   terminating
    ///   [`Nlmsg::Done`][crate::consts::nl::Nlmsg::Done] message, after
    ///   which [`None`] is returned.
    pub fn new(sock_ref: &'a mut NlSocketHandle, behavior: IterationBehavior) -> Self {
        // `None` disables termination tracking entirely; `Some(false)`
        // arms it with "no terminating message seen yet".
        let next_is_none = match behavior {
            IterationBehavior::IterIndefinitely => None,
            IterationBehavior::EndMultiOnDone => Some(false),
        };

        NlMessageIter {
            sock_ref,
            next_is_none,
            type_: PhantomData,
            payload: PhantomData,
        }
    }

    /// Receive the next message from the socket, parsed as
    /// `Nlmsghdr<TT, PP>`.
    ///
    /// Returns `None` once a terminating message has already been
    /// yielded (when termination tracking is enabled) or when the
    /// socket's `recv` reports no message. A receive error is yielded
    /// as `Some(Err(_))` and does not change the termination state.
    fn next<TT, PP>(&mut self) -> Option<Result<Nlmsghdr<TT, PP>, NlError<TT, PP>>>
    where
        TT: NlType + Debug,
        PP: for<'c> FromBytesWithInput<'c, Input = usize> + Debug,
    {
        if self.next_is_none == Some(true) {
            return None;
        }

        // `transpose` turns `Ok(None)` into `None`, which `?` propagates
        // as the end of iteration.
        let msg = match self.sock_ref.recv::<TT, PP>().transpose()? {
            Ok(received) => received,
            Err(err) => return Some(Err(err)),
        };

        // An ACK always terminates. Otherwise a non-multipart message or
        // a DONE message terminates, but only while the socket is not
        // still waiting for an ACK.
        let is_last = matches!(msg.nl_payload, NlPayload::Ack(_))
            || ((!msg.nl_flags.contains(&NlmF::Multi)
                || msg.nl_type.into() == Nlmsg::Done.into())
                && !self.sock_ref.needs_ack);

        if is_last {
            // Only flip the flag when tracking is enabled; `None`
            // (iterate indefinitely) stays `None`.
            if let Some(done) = self.next_is_none.as_mut() {
                *done = true;
            }
        }

        Some(Ok(msg))
    }
}

/// Standard iterator interface over received netlink messages.
///
/// Each item is either a parsed `Nlmsghdr<T, P>` or the `NlError<T, P>`
/// produced while receiving it.
impl<'a, T, P> Iterator for NlMessageIter<'a, T, P>
where
    T: NlType + Debug,
    P: for<'b> FromBytesWithInput<'b, Input = usize> + Debug,
{
    type Item = Result<Nlmsghdr<T, P>, NlError<T, P>>;

    /// Delegate to the inherent, generic `next`, fixing the message and
    /// payload types to this iterator's own `T` and `P`.
    fn next(&mut self) -> Option<Self::Item> {
        Self::next::<T, P>(self)
    }
}