// neli/socket/synchronous.rs
use std::{
    fmt::Debug,
    io::Cursor,
    os::unix::io::{AsRawFd, IntoRawFd, RawFd},
};

use log::trace;

use crate::{
    FromBytesWithInput, Size, ToBytes,
    consts::{nl::*, socket::*},
    err::SocketError,
    iter::NlBufferIter,
    nl::Nlmsghdr,
    socket::shared::NlSocket,
    types::NlBuffer,
    utils::{
        Groups, NetlinkBitArray,
        synchronous::{BufferPool, BufferPoolGuard},
    },
};

23pub struct NlSocketHandle {
25 pub(super) socket: NlSocket,
26 pid: u32,
27 pool: BufferPool,
28}
29
30impl NlSocketHandle {
31 pub fn connect(proto: NlFamily, pid: Option<u32>, groups: Groups) -> Result<Self, SocketError> {
33 let socket = NlSocket::connect(proto, pid, groups)?;
34 socket.block()?;
35 let pid = socket.pid()?;
36 Ok(NlSocketHandle {
37 socket,
38 pid,
39 pool: BufferPool::default(),
40 })
41 }
42
43 pub fn add_mcast_membership(&self, groups: Groups) -> Result<(), SocketError> {
45 self.socket
46 .add_mcast_membership(groups)
47 .map_err(SocketError::from)
48 }
49
50 pub fn drop_mcast_membership(&self, groups: Groups) -> Result<(), SocketError> {
52 self.socket
53 .drop_mcast_membership(groups)
54 .map_err(SocketError::from)
55 }
56
57 pub fn list_mcast_membership(&self) -> Result<NetlinkBitArray, SocketError> {
59 self.socket
60 .list_mcast_membership()
61 .map_err(SocketError::from)
62 }
63
64 pub fn pid(&self) -> u32 {
66 self.pid
67 }
68
69 pub fn send<T, P>(&self, msg: &Nlmsghdr<T, P>) -> Result<(), SocketError>
71 where
72 T: NlType + Debug,
73 P: Size + ToBytes + Debug,
74 {
75 trace!("Message sent:\n{msg:?}");
76
77 let mut buffer = Cursor::new(vec![0; msg.padded_size()]);
78 msg.to_bytes(&mut buffer)?;
79 trace!("Buffer sent: {:?}", buffer.get_ref());
80 self.socket.send(buffer.get_ref(), Msg::empty())?;
81
82 Ok(())
83 }
84
85 pub fn recv<T, P>(
94 &self,
95 ) -> Result<(NlBufferIter<T, P, BufferPoolGuard<'_>>, Groups), SocketError>
96 where
97 T: NlType + Debug,
98 P: Size + FromBytesWithInput<Input = usize> + Debug,
99 {
100 let mut buffer = self.pool.acquire();
101 let (mem_read, groups) = self.socket.recv(&mut buffer, Msg::empty())?;
102 buffer.reduce_size(mem_read);
103 trace!("Buffer received: {:?}", buffer.as_ref());
104 Ok((NlBufferIter::new(Cursor::new(buffer)), groups))
105 }
106
107 pub fn recv_all<T, P>(&self) -> Result<(NlBuffer<T, P>, Groups), SocketError>
116 where
117 T: NlType + Debug,
118 P: Size + FromBytesWithInput<Input = usize> + Debug,
119 {
120 let mut buffer = self.pool.acquire();
121 let (mem_read, groups) = self.socket.recv(&mut buffer, Msg::empty())?;
122 if mem_read == 0 {
123 return Ok((NlBuffer::new(), Groups::empty()));
124 }
125 buffer.reduce_size(mem_read);
126
127 let vec = NlBuffer::from_bytes_with_input(&mut Cursor::new(buffer), mem_read)?;
128
129 trace!("Messages received: {vec:?}");
130
131 Ok((vec, groups))
132 }
133
134 pub fn set_recv_buffer_size(&self, size: usize) -> Result<(), SocketError> {
148 self.socket
149 .set_recv_buffer_size(size)
150 .map_err(SocketError::from)
151 }
152
153 pub fn enable_ext_ack(&self, enable: bool) -> Result<(), SocketError> {
156 self.socket
157 .enable_ext_ack(enable)
158 .map_err(SocketError::from)
159 }
160
161 pub fn get_ext_ack_enabled(&self) -> Result<bool, SocketError> {
163 self.socket.get_ext_ack_enabled().map_err(SocketError::from)
164 }
165
166 pub fn enable_strict_checking(&self, enable: bool) -> Result<(), SocketError> {
171 self.socket
172 .enable_strict_checking(enable)
173 .map_err(SocketError::from)
174 }
175
176 pub fn get_strict_checking_enabled(&self) -> Result<bool, SocketError> {
180 self.socket
181 .get_strict_checking_enabled()
182 .map_err(SocketError::from)
183 }
184}
185
186impl AsRawFd for NlSocketHandle {
187 fn as_raw_fd(&self) -> RawFd {
188 self.socket.as_raw_fd()
189 }
190}
191
192impl IntoRawFd for NlSocketHandle {
193 fn into_raw_fd(self) -> RawFd {
194 self.socket.into_raw_fd()
195 }
196}