shadow_rs/host/descriptor/socket/mod.rs

1use std::sync::Arc;
2
3use atomic_refcell::AtomicRefCell;
4use inet::{InetSocket, InetSocketRef, InetSocketRefMut};
5use linux_api::errno::Errno;
6use linux_api::ioctls::IoctlRequest;
7use linux_api::socket::Shutdown;
8use netlink::NetlinkSocket;
9use shadow_shim_helper_rs::syscall_types::ForeignPtr;
10use unix::UnixSocket;
11
12use crate::cshadow as c;
13use crate::host::descriptor::listener::{StateListenHandle, StateListenerFilter};
14use crate::host::descriptor::{
15    FileMode, FileSignals, FileState, FileStatus, OpenFile, SyscallResult,
16};
17use crate::host::memory_manager::MemoryManager;
18use crate::host::network::namespace::NetworkNamespace;
19use crate::host::syscall::io::IoVec;
20use crate::host::syscall::types::{ForeignArrayPtr, SyscallError};
21use crate::utility::HostTreePointer;
22use crate::utility::callback_queue::CallbackQueue;
23use crate::utility::sockaddr::SockaddrStorage;
24
25pub mod abstract_unix_ns;
26pub mod inet;
27pub mod netlink;
28pub mod unix;
29
30bitflags::bitflags! {
31    /// Flags to represent if a socket has been shut down for reading and/or writing. An empty set
32    /// of flags implies that the socket *has not* been shut down for reading or writing.
33    #[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
34    struct ShutdownFlags: u8 {
35        const READ = 0b00000001;
36        const WRITE = 0b00000010;
37    }
38}
39
40#[derive(Clone)]
41pub enum Socket {
42    Unix(Arc<AtomicRefCell<UnixSocket>>),
43    Inet(InetSocket),
44    Netlink(Arc<AtomicRefCell<NetlinkSocket>>),
45}
46
47impl Socket {
48    pub fn borrow(&self) -> SocketRef {
49        match self {
50            Self::Unix(f) => SocketRef::Unix(f.borrow()),
51            Self::Inet(f) => SocketRef::Inet(f.borrow()),
52            Self::Netlink(f) => SocketRef::Netlink(f.borrow()),
53        }
54    }
55
56    pub fn try_borrow(&self) -> Result<SocketRef, atomic_refcell::BorrowError> {
57        Ok(match self {
58            Self::Unix(f) => SocketRef::Unix(f.try_borrow()?),
59            Self::Inet(f) => SocketRef::Inet(f.try_borrow()?),
60            Self::Netlink(f) => SocketRef::Netlink(f.try_borrow()?),
61        })
62    }
63
64    pub fn borrow_mut(&self) -> SocketRefMut {
65        match self {
66            Self::Unix(f) => SocketRefMut::Unix(f.borrow_mut()),
67            Self::Inet(f) => SocketRefMut::Inet(f.borrow_mut()),
68            Self::Netlink(f) => SocketRefMut::Netlink(f.borrow_mut()),
69        }
70    }
71
72    pub fn try_borrow_mut(&self) -> Result<SocketRefMut, atomic_refcell::BorrowMutError> {
73        Ok(match self {
74            Self::Unix(f) => SocketRefMut::Unix(f.try_borrow_mut()?),
75            Self::Inet(f) => SocketRefMut::Inet(f.try_borrow_mut()?),
76            Self::Netlink(f) => SocketRefMut::Netlink(f.try_borrow_mut()?),
77        })
78    }
79
80    pub fn canonical_handle(&self) -> usize {
81        match self {
82            Self::Unix(f) => Arc::as_ptr(f) as usize,
83            Self::Inet(f) => f.canonical_handle(),
84            Self::Netlink(f) => Arc::as_ptr(f) as usize,
85        }
86    }
87
88    pub fn bind(
89        &self,
90        addr: Option<&SockaddrStorage>,
91        net_ns: &NetworkNamespace,
92        rng: impl rand::Rng,
93    ) -> Result<(), SyscallError> {
94        match self {
95            Self::Unix(socket) => UnixSocket::bind(socket, addr, net_ns, rng),
96            Self::Inet(socket) => InetSocket::bind(socket, addr, net_ns, rng),
97            Self::Netlink(socket) => NetlinkSocket::bind(socket, addr, net_ns, rng),
98        }
99    }
100
101    pub fn listen(
102        &self,
103        backlog: i32,
104        net_ns: &NetworkNamespace,
105        rng: impl rand::Rng,
106        cb_queue: &mut CallbackQueue,
107    ) -> Result<(), Errno> {
108        match self {
109            Self::Unix(socket) => UnixSocket::listen(socket, backlog, net_ns, rng, cb_queue),
110            Self::Inet(socket) => InetSocket::listen(socket, backlog, net_ns, rng, cb_queue),
111            Self::Netlink(socket) => NetlinkSocket::listen(socket, backlog, net_ns, rng, cb_queue),
112        }
113    }
114
115    pub fn connect(
116        &self,
117        addr: &SockaddrStorage,
118        net_ns: &NetworkNamespace,
119        rng: impl rand::Rng,
120        cb_queue: &mut CallbackQueue,
121    ) -> Result<(), SyscallError> {
122        match self {
123            Self::Unix(socket) => UnixSocket::connect(socket, addr, net_ns, rng, cb_queue),
124            Self::Inet(socket) => InetSocket::connect(socket, addr, net_ns, rng, cb_queue),
125            Self::Netlink(socket) => NetlinkSocket::connect(socket, addr, net_ns, rng, cb_queue),
126        }
127    }
128
129    pub fn sendmsg(
130        &self,
131        args: SendmsgArgs,
132        memory_manager: &mut MemoryManager,
133        net_ns: &NetworkNamespace,
134        rng: impl rand::Rng,
135        cb_queue: &mut CallbackQueue,
136    ) -> Result<libc::ssize_t, SyscallError> {
137        match self {
138            Self::Unix(socket) => {
139                UnixSocket::sendmsg(socket, args, memory_manager, net_ns, rng, cb_queue)
140            }
141            Self::Inet(socket) => {
142                InetSocket::sendmsg(socket, args, memory_manager, net_ns, rng, cb_queue)
143            }
144            Self::Netlink(socket) => {
145                NetlinkSocket::sendmsg(socket, args, memory_manager, net_ns, rng, cb_queue)
146            }
147        }
148    }
149
150    pub fn recvmsg(
151        &self,
152        args: RecvmsgArgs,
153        memory_manager: &mut MemoryManager,
154        cb_queue: &mut CallbackQueue,
155    ) -> Result<RecvmsgReturn, SyscallError> {
156        match self {
157            Self::Unix(socket) => UnixSocket::recvmsg(socket, args, memory_manager, cb_queue),
158            Self::Inet(socket) => InetSocket::recvmsg(socket, args, memory_manager, cb_queue),
159            Self::Netlink(socket) => NetlinkSocket::recvmsg(socket, args, memory_manager, cb_queue),
160        }
161    }
162}
163
164impl std::fmt::Debug for Socket {
165    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
166        match self {
167            Self::Unix(_) => write!(f, "Unix")?,
168            Self::Inet(_) => write!(f, "Inet")?,
169            Self::Netlink(_) => write!(f, "Netlink")?,
170        }
171
172        if let Ok(file) = self.try_borrow() {
173            write!(
174                f,
175                "(state: {:?}, status: {:?})",
176                file.state(),
177                file.status()
178            )
179        } else {
180            write!(f, "(already borrowed)")
181        }
182    }
183}
184
185pub enum SocketRef<'a> {
186    Unix(atomic_refcell::AtomicRef<'a, UnixSocket>),
187    Inet(InetSocketRef<'a>),
188    Netlink(atomic_refcell::AtomicRef<'a, NetlinkSocket>),
189}
190
191pub enum SocketRefMut<'a> {
192    Unix(atomic_refcell::AtomicRefMut<'a, UnixSocket>),
193    Inet(InetSocketRefMut<'a>),
194    Netlink(atomic_refcell::AtomicRefMut<'a, NetlinkSocket>),
195}
196
197// file functions
198impl SocketRef<'_> {
199    enum_passthrough!(self, (), Unix, Inet, Netlink;
200        pub fn state(&self) -> FileState
201    );
202    enum_passthrough!(self, (), Unix, Inet, Netlink;
203        pub fn mode(&self) -> FileMode
204    );
205    enum_passthrough!(self, (), Unix, Inet, Netlink;
206        pub fn status(&self) -> FileStatus
207    );
208    enum_passthrough!(self, (), Unix, Inet, Netlink;
209        pub fn stat(&self) -> Result<linux_api::stat::stat, SyscallError>
210    );
211    enum_passthrough!(self, (), Unix, Inet, Netlink;
212        pub fn has_open_file(&self) -> bool
213    );
214    enum_passthrough!(self, (), Unix, Inet, Netlink;
215        pub fn supports_sa_restart(&self) -> bool
216    );
217}
218
219// socket-specific functions
220impl SocketRef<'_> {
221    pub fn getpeername(&self) -> Result<Option<SockaddrStorage>, Errno> {
222        match self {
223            Self::Unix(socket) => socket.getpeername().map(|opt| opt.map(Into::into)),
224            Self::Inet(socket) => socket.getpeername(),
225            Self::Netlink(socket) => socket.getpeername().map(|opt| opt.map(Into::into)),
226        }
227    }
228
229    pub fn getsockname(&self) -> Result<Option<SockaddrStorage>, Errno> {
230        match self {
231            Self::Unix(socket) => socket.getsockname().map(|opt| opt.map(Into::into)),
232            Self::Inet(socket) => socket.getsockname(),
233            Self::Netlink(socket) => socket.getsockname().map(|opt| opt.map(Into::into)),
234        }
235    }
236
237    enum_passthrough!(self, (), Unix, Inet, Netlink;
238        pub fn address_family(&self) -> linux_api::socket::AddressFamily
239    );
240}
241
242// file functions
243impl SocketRefMut<'_> {
244    enum_passthrough!(self, (), Unix, Inet, Netlink;
245        pub fn state(&self) -> FileState
246    );
247    enum_passthrough!(self, (), Unix, Inet, Netlink;
248        pub fn mode(&self) -> FileMode
249    );
250    enum_passthrough!(self, (), Unix, Inet, Netlink;
251        pub fn status(&self) -> FileStatus
252    );
253    enum_passthrough!(self, (), Unix, Inet, Netlink;
254        pub fn stat(&self) -> Result<linux_api::stat::stat, SyscallError>
255    );
256    enum_passthrough!(self, (), Unix, Inet, Netlink;
257        pub fn has_open_file(&self) -> bool
258    );
259    enum_passthrough!(self, (val), Unix, Inet, Netlink;
260        pub fn set_has_open_file(&mut self, val: bool)
261    );
262    enum_passthrough!(self, (), Unix, Inet, Netlink;
263        pub fn supports_sa_restart(&self) -> bool
264    );
265    enum_passthrough!(self, (cb_queue), Unix, Inet, Netlink;
266        pub fn close(&mut self, cb_queue: &mut CallbackQueue) -> Result<(), SyscallError>
267    );
268    enum_passthrough!(self, (status), Unix, Inet, Netlink;
269        pub fn set_status(&mut self, status: FileStatus)
270    );
271    enum_passthrough!(self, (request, arg_ptr, memory_manager), Unix, Inet, Netlink;
272        pub fn ioctl(&mut self, request: IoctlRequest, arg_ptr: ForeignPtr<()>, memory_manager: &mut MemoryManager) -> SyscallResult
273    );
274    enum_passthrough!(self, (monitoring_state, monitoring_signals, filter, notify_fn), Unix, Inet, Netlink;
275        pub fn add_listener(
276            &mut self,
277            monitoring_state: FileState,
278            monitoring_signals: FileSignals,
279            filter: StateListenerFilter,
280            notify_fn: impl Fn(FileState, FileState, FileSignals, &mut CallbackQueue) + Send + Sync + 'static,
281        ) -> StateListenHandle
282    );
283    enum_passthrough!(self, (ptr), Unix, Inet, Netlink;
284        pub fn add_legacy_listener(&mut self, ptr: HostTreePointer<c::StatusListener>)
285    );
286    enum_passthrough!(self, (ptr), Unix, Inet, Netlink;
287        pub fn remove_legacy_listener(&mut self, ptr: *mut c::StatusListener)
288    );
289    enum_passthrough!(self, (iovs, offset, flags, mem, cb_queue), Unix, Inet, Netlink;
290        pub fn readv(&mut self, iovs: &[IoVec], offset: Option<libc::off_t>, flags: libc::c_int,
291                     mem: &mut MemoryManager, cb_queue: &mut CallbackQueue) -> Result<libc::ssize_t, SyscallError>
292    );
293    enum_passthrough!(self, (iovs, offset, flags, mem, cb_queue), Unix, Inet, Netlink;
294        pub fn writev(&mut self, iovs: &[IoVec], offset: Option<libc::off_t>, flags: libc::c_int,
295                      mem: &mut MemoryManager, cb_queue: &mut CallbackQueue) -> Result<libc::ssize_t, SyscallError>
296    );
297}
298
299// socket-specific functions
300impl SocketRefMut<'_> {
301    pub fn getpeername(&self) -> Result<Option<SockaddrStorage>, Errno> {
302        match self {
303            Self::Unix(socket) => socket.getpeername().map(|opt| opt.map(Into::into)),
304            Self::Inet(socket) => socket.getpeername(),
305            Self::Netlink(socket) => socket.getpeername().map(|opt| opt.map(Into::into)),
306        }
307    }
308
309    pub fn getsockname(&self) -> Result<Option<SockaddrStorage>, Errno> {
310        match self {
311            Self::Unix(socket) => socket.getsockname().map(|opt| opt.map(Into::into)),
312            Self::Inet(socket) => socket.getsockname(),
313            Self::Netlink(socket) => socket.getsockname().map(|opt| opt.map(Into::into)),
314        }
315    }
316
317    enum_passthrough!(self, (), Unix, Inet, Netlink;
318        pub fn address_family(&self) -> linux_api::socket::AddressFamily
319    );
320
321    enum_passthrough!(self, (level, optname, optval_ptr, optlen, memory_manager, cb_queue), Unix, Inet, Netlink;
322        pub fn getsockopt(&mut self, level: libc::c_int, optname: libc::c_int, optval_ptr: ForeignPtr<()>,
323                          optlen: libc::socklen_t, memory_manager: &mut MemoryManager, cb_queue: &mut CallbackQueue)
324        -> Result<libc::socklen_t, SyscallError>
325    );
326
327    enum_passthrough!(self, (level, optname, optval_ptr, optlen, memory_manager), Unix, Inet, Netlink;
328        pub fn setsockopt(&mut self, level: libc::c_int, optname: libc::c_int, optval_ptr: ForeignPtr<()>,
329                          optlen: libc::socklen_t, memory_manager: &MemoryManager)
330        -> Result<(), SyscallError>
331    );
332
333    pub fn accept(
334        &mut self,
335        net_ns: &NetworkNamespace,
336        rng: impl rand::Rng,
337        cb_queue: &mut CallbackQueue,
338    ) -> Result<OpenFile, SyscallError> {
339        match self {
340            Self::Unix(socket) => socket.accept(net_ns, rng, cb_queue),
341            Self::Inet(socket) => socket.accept(net_ns, rng, cb_queue),
342            Self::Netlink(socket) => socket.accept(net_ns, rng, cb_queue),
343        }
344    }
345
346    enum_passthrough!(self, (how, cb_queue), Unix, Inet, Netlink;
347        pub fn shutdown(&mut self, how: Shutdown, cb_queue: &mut CallbackQueue) -> Result<(), SyscallError>
348    );
349}
350
351impl std::fmt::Debug for SocketRef<'_> {
352    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
353        match self {
354            Self::Unix(_) => write!(f, "Unix")?,
355            Self::Inet(_) => write!(f, "Inet")?,
356            Self::Netlink(_) => write!(f, "Netlink")?,
357        }
358
359        write!(
360            f,
361            "(state: {:?}, status: {:?})",
362            self.state(),
363            self.status()
364        )
365    }
366}
367
368impl std::fmt::Debug for SocketRefMut<'_> {
369    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
370        match self {
371            Self::Unix(_) => write!(f, "Unix")?,
372            Self::Inet(_) => write!(f, "Inet")?,
373            Self::Netlink(_) => write!(f, "Netlink")?,
374        }
375
376        write!(
377            f,
378            "(state: {:?}, status: {:?})",
379            self.state(),
380            self.status()
381        )
382    }
383}
384
385/// Arguments for [`Socket::sendmsg()`].
386pub struct SendmsgArgs<'a> {
387    /// Socket address to send the message to.
388    pub addr: Option<SockaddrStorage>,
389    /// [`IoVec`] buffers in plugin memory containing the message data.
390    pub iovs: &'a [IoVec],
391    /// Buffer in plugin memory containg message control data.
392    pub control_ptr: ForeignArrayPtr<u8>,
393    /// Send flags.
394    pub flags: libc::c_int,
395}
396
397/// Arguments for [`Socket::recvmsg()`].
398pub struct RecvmsgArgs<'a> {
399    /// [`IoVec`] buffers in plugin memory to store the message data.
400    pub iovs: &'a [IoVec],
401    /// Buffer in plugin memory to store the message control data.
402    pub control_ptr: ForeignArrayPtr<u8>,
403    /// Recv flags.
404    pub flags: libc::c_int,
405}
406
407/// Return values for [`Socket::recvmsg()`].
408pub struct RecvmsgReturn {
409    /// The return value for the syscall. Typically is the number of message bytes read, but is
410    /// modifiable by the syscall flag.
411    pub return_val: libc::ssize_t,
412    /// The socket address of the received message.
413    pub addr: Option<SockaddrStorage>,
414    /// Message flags.
415    pub msg_flags: libc::c_int,
416    /// The number of control data bytes read.
417    pub control_len: libc::size_t,
418}