1use std::sync::Arc;
2
3use atomic_refcell::AtomicRefCell;
4use inet::{InetSocket, InetSocketRef, InetSocketRefMut};
5use linux_api::errno::Errno;
6use linux_api::ioctls::IoctlRequest;
7use linux_api::socket::Shutdown;
8use netlink::NetlinkSocket;
9use shadow_shim_helper_rs::syscall_types::ForeignPtr;
10use unix::UnixSocket;
11
12use crate::cshadow as c;
13use crate::host::descriptor::listener::{StateListenHandle, StateListenerFilter};
14use crate::host::descriptor::{
15 FileMode, FileSignals, FileState, FileStatus, OpenFile, SyscallResult,
16};
17use crate::host::memory_manager::MemoryManager;
18use crate::host::network::namespace::NetworkNamespace;
19use crate::host::syscall::io::IoVec;
20use crate::host::syscall::types::{ForeignArrayPtr, SyscallError};
21use crate::utility::HostTreePointer;
22use crate::utility::callback_queue::CallbackQueue;
23use crate::utility::sockaddr::SockaddrStorage;
24
25pub mod abstract_unix_ns;
26pub mod inet;
27pub mod netlink;
28pub mod unix;
29
30bitflags::bitflags! {
31 #[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
34 struct ShutdownFlags: u8 {
35 const READ = 0b00000001;
36 const WRITE = 0b00000010;
37 }
38}
39
40#[derive(Clone)]
41pub enum Socket {
42 Unix(Arc<AtomicRefCell<UnixSocket>>),
43 Inet(InetSocket),
44 Netlink(Arc<AtomicRefCell<NetlinkSocket>>),
45}
46
47impl Socket {
48 pub fn borrow(&self) -> SocketRef {
49 match self {
50 Self::Unix(f) => SocketRef::Unix(f.borrow()),
51 Self::Inet(f) => SocketRef::Inet(f.borrow()),
52 Self::Netlink(f) => SocketRef::Netlink(f.borrow()),
53 }
54 }
55
56 pub fn try_borrow(&self) -> Result<SocketRef, atomic_refcell::BorrowError> {
57 Ok(match self {
58 Self::Unix(f) => SocketRef::Unix(f.try_borrow()?),
59 Self::Inet(f) => SocketRef::Inet(f.try_borrow()?),
60 Self::Netlink(f) => SocketRef::Netlink(f.try_borrow()?),
61 })
62 }
63
64 pub fn borrow_mut(&self) -> SocketRefMut {
65 match self {
66 Self::Unix(f) => SocketRefMut::Unix(f.borrow_mut()),
67 Self::Inet(f) => SocketRefMut::Inet(f.borrow_mut()),
68 Self::Netlink(f) => SocketRefMut::Netlink(f.borrow_mut()),
69 }
70 }
71
72 pub fn try_borrow_mut(&self) -> Result<SocketRefMut, atomic_refcell::BorrowMutError> {
73 Ok(match self {
74 Self::Unix(f) => SocketRefMut::Unix(f.try_borrow_mut()?),
75 Self::Inet(f) => SocketRefMut::Inet(f.try_borrow_mut()?),
76 Self::Netlink(f) => SocketRefMut::Netlink(f.try_borrow_mut()?),
77 })
78 }
79
80 pub fn canonical_handle(&self) -> usize {
81 match self {
82 Self::Unix(f) => Arc::as_ptr(f) as usize,
83 Self::Inet(f) => f.canonical_handle(),
84 Self::Netlink(f) => Arc::as_ptr(f) as usize,
85 }
86 }
87
88 pub fn bind(
89 &self,
90 addr: Option<&SockaddrStorage>,
91 net_ns: &NetworkNamespace,
92 rng: impl rand::Rng,
93 ) -> Result<(), SyscallError> {
94 match self {
95 Self::Unix(socket) => UnixSocket::bind(socket, addr, net_ns, rng),
96 Self::Inet(socket) => InetSocket::bind(socket, addr, net_ns, rng),
97 Self::Netlink(socket) => NetlinkSocket::bind(socket, addr, net_ns, rng),
98 }
99 }
100
101 pub fn listen(
102 &self,
103 backlog: i32,
104 net_ns: &NetworkNamespace,
105 rng: impl rand::Rng,
106 cb_queue: &mut CallbackQueue,
107 ) -> Result<(), Errno> {
108 match self {
109 Self::Unix(socket) => UnixSocket::listen(socket, backlog, net_ns, rng, cb_queue),
110 Self::Inet(socket) => InetSocket::listen(socket, backlog, net_ns, rng, cb_queue),
111 Self::Netlink(socket) => NetlinkSocket::listen(socket, backlog, net_ns, rng, cb_queue),
112 }
113 }
114
115 pub fn connect(
116 &self,
117 addr: &SockaddrStorage,
118 net_ns: &NetworkNamespace,
119 rng: impl rand::Rng,
120 cb_queue: &mut CallbackQueue,
121 ) -> Result<(), SyscallError> {
122 match self {
123 Self::Unix(socket) => UnixSocket::connect(socket, addr, net_ns, rng, cb_queue),
124 Self::Inet(socket) => InetSocket::connect(socket, addr, net_ns, rng, cb_queue),
125 Self::Netlink(socket) => NetlinkSocket::connect(socket, addr, net_ns, rng, cb_queue),
126 }
127 }
128
129 pub fn sendmsg(
130 &self,
131 args: SendmsgArgs,
132 memory_manager: &mut MemoryManager,
133 net_ns: &NetworkNamespace,
134 rng: impl rand::Rng,
135 cb_queue: &mut CallbackQueue,
136 ) -> Result<libc::ssize_t, SyscallError> {
137 match self {
138 Self::Unix(socket) => {
139 UnixSocket::sendmsg(socket, args, memory_manager, net_ns, rng, cb_queue)
140 }
141 Self::Inet(socket) => {
142 InetSocket::sendmsg(socket, args, memory_manager, net_ns, rng, cb_queue)
143 }
144 Self::Netlink(socket) => {
145 NetlinkSocket::sendmsg(socket, args, memory_manager, net_ns, rng, cb_queue)
146 }
147 }
148 }
149
150 pub fn recvmsg(
151 &self,
152 args: RecvmsgArgs,
153 memory_manager: &mut MemoryManager,
154 cb_queue: &mut CallbackQueue,
155 ) -> Result<RecvmsgReturn, SyscallError> {
156 match self {
157 Self::Unix(socket) => UnixSocket::recvmsg(socket, args, memory_manager, cb_queue),
158 Self::Inet(socket) => InetSocket::recvmsg(socket, args, memory_manager, cb_queue),
159 Self::Netlink(socket) => NetlinkSocket::recvmsg(socket, args, memory_manager, cb_queue),
160 }
161 }
162}
163
164impl std::fmt::Debug for Socket {
165 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
166 match self {
167 Self::Unix(_) => write!(f, "Unix")?,
168 Self::Inet(_) => write!(f, "Inet")?,
169 Self::Netlink(_) => write!(f, "Netlink")?,
170 }
171
172 if let Ok(file) = self.try_borrow() {
173 write!(
174 f,
175 "(state: {:?}, status: {:?})",
176 file.state(),
177 file.status()
178 )
179 } else {
180 write!(f, "(already borrowed)")
181 }
182 }
183}
184
185pub enum SocketRef<'a> {
186 Unix(atomic_refcell::AtomicRef<'a, UnixSocket>),
187 Inet(InetSocketRef<'a>),
188 Netlink(atomic_refcell::AtomicRef<'a, NetlinkSocket>),
189}
190
191pub enum SocketRefMut<'a> {
192 Unix(atomic_refcell::AtomicRefMut<'a, UnixSocket>),
193 Inet(InetSocketRefMut<'a>),
194 Netlink(atomic_refcell::AtomicRefMut<'a, NetlinkSocket>),
195}
196
// File-level accessors on an immutable socket borrow.
//
// Each `enum_passthrough!` call below generates the method written after the `;`. The macro is
// defined elsewhere; presumably it expands to a `match` over the listed variants (Unix, Inet,
// Netlink) that calls the same-named method on each variant's inner value, passing the
// arguments named in the parenthesised list (empty here) — confirm against the macro definition.
impl SocketRef<'_> {
    // file functions

    // Current file state of the wrapped socket.
    enum_passthrough!(self, (), Unix, Inet, Netlink;
        pub fn state(&self) -> FileState
    );
    // Access mode of the wrapped socket.
    enum_passthrough!(self, (), Unix, Inet, Netlink;
        pub fn mode(&self) -> FileMode
    );
    // File status flags of the wrapped socket.
    enum_passthrough!(self, (), Unix, Inet, Netlink;
        pub fn status(&self) -> FileStatus
    );
    // `stat` information for the wrapped socket.
    enum_passthrough!(self, (), Unix, Inet, Netlink;
        pub fn stat(&self) -> Result<linux_api::stat::stat, SyscallError>
    );
    // Whether the wrapped socket is flagged as having an open file.
    enum_passthrough!(self, (), Unix, Inet, Netlink;
        pub fn has_open_file(&self) -> bool
    );
    // Whether the wrapped socket supports SA_RESTART.
    enum_passthrough!(self, (), Unix, Inet, Netlink;
        pub fn supports_sa_restart(&self) -> bool
    );
}
218
219impl SocketRef<'_> {
221 pub fn getpeername(&self) -> Result<Option<SockaddrStorage>, Errno> {
222 match self {
223 Self::Unix(socket) => socket.getpeername().map(|opt| opt.map(Into::into)),
224 Self::Inet(socket) => socket.getpeername(),
225 Self::Netlink(socket) => socket.getpeername().map(|opt| opt.map(Into::into)),
226 }
227 }
228
229 pub fn getsockname(&self) -> Result<Option<SockaddrStorage>, Errno> {
230 match self {
231 Self::Unix(socket) => socket.getsockname().map(|opt| opt.map(Into::into)),
232 Self::Inet(socket) => socket.getsockname(),
233 Self::Netlink(socket) => socket.getsockname().map(|opt| opt.map(Into::into)),
234 }
235 }
236
237 enum_passthrough!(self, (), Unix, Inet, Netlink;
238 pub fn address_family(&self) -> linux_api::socket::AddressFamily
239 );
240}
241
// File-level operations on a mutable socket borrow.
//
// Each `enum_passthrough!` call below generates the method written after the `;`. The macro is
// defined elsewhere; presumably it expands to a `match` over the listed variants (Unix, Inet,
// Netlink) that calls the same-named method on each variant's inner value, forwarding the
// arguments named in the parenthesised list — confirm against the macro definition.
impl SocketRefMut<'_> {
    // file functions

    // Current file state of the wrapped socket.
    enum_passthrough!(self, (), Unix, Inet, Netlink;
        pub fn state(&self) -> FileState
    );
    // Access mode of the wrapped socket.
    enum_passthrough!(self, (), Unix, Inet, Netlink;
        pub fn mode(&self) -> FileMode
    );
    // File status flags of the wrapped socket.
    enum_passthrough!(self, (), Unix, Inet, Netlink;
        pub fn status(&self) -> FileStatus
    );
    // `stat` information for the wrapped socket.
    enum_passthrough!(self, (), Unix, Inet, Netlink;
        pub fn stat(&self) -> Result<linux_api::stat::stat, SyscallError>
    );
    // Whether the wrapped socket is flagged as having an open file.
    enum_passthrough!(self, (), Unix, Inet, Netlink;
        pub fn has_open_file(&self) -> bool
    );
    // Sets the "has open file" flag on the wrapped socket (forwards `val`).
    enum_passthrough!(self, (val), Unix, Inet, Netlink;
        pub fn set_has_open_file(&mut self, val: bool)
    );
    // Whether the wrapped socket supports SA_RESTART.
    enum_passthrough!(self, (), Unix, Inet, Netlink;
        pub fn supports_sa_restart(&self) -> bool
    );
    // Closes the wrapped socket; `cb_queue` is forwarded for any callbacks the close schedules.
    enum_passthrough!(self, (cb_queue), Unix, Inet, Netlink;
        pub fn close(&mut self, cb_queue: &mut CallbackQueue) -> Result<(), SyscallError>
    );
    // Replaces the file status flags of the wrapped socket.
    enum_passthrough!(self, (status), Unix, Inet, Netlink;
        pub fn set_status(&mut self, status: FileStatus)
    );
    // Handles an ioctl; `arg_ptr` is a `ForeignPtr` resolved through `memory_manager`.
    enum_passthrough!(self, (request, arg_ptr, memory_manager), Unix, Inet, Netlink;
        pub fn ioctl(&mut self, request: IoctlRequest, arg_ptr: ForeignPtr<()>, memory_manager: &mut MemoryManager) -> SyscallResult
    );
    // Registers a state/signal listener on the wrapped socket and returns its handle.
    enum_passthrough!(self, (monitoring_state, monitoring_signals, filter, notify_fn), Unix, Inet, Netlink;
        pub fn add_listener(
            &mut self,
            monitoring_state: FileState,
            monitoring_signals: FileSignals,
            filter: StateListenerFilter,
            notify_fn: impl Fn(FileState, FileState, FileSignals, &mut CallbackQueue) + Send + Sync + 'static,
        ) -> StateListenHandle
    );
    // Registers a C `StatusListener` (held via `HostTreePointer`).
    enum_passthrough!(self, (ptr), Unix, Inet, Netlink;
        pub fn add_legacy_listener(&mut self, ptr: HostTreePointer<c::StatusListener>)
    );
    // Unregisters a C `StatusListener` identified by raw pointer.
    enum_passthrough!(self, (ptr), Unix, Inet, Netlink;
        pub fn remove_legacy_listener(&mut self, ptr: *mut c::StatusListener)
    );
    // Vectored read into `iovs`; returns the byte count on success.
    enum_passthrough!(self, (iovs, offset, flags, mem, cb_queue), Unix, Inet, Netlink;
        pub fn readv(&mut self, iovs: &[IoVec], offset: Option<libc::off_t>, flags: libc::c_int,
                     mem: &mut MemoryManager, cb_queue: &mut CallbackQueue) -> Result<libc::ssize_t, SyscallError>
    );
    // Vectored write from `iovs`; returns the byte count on success.
    enum_passthrough!(self, (iovs, offset, flags, mem, cb_queue), Unix, Inet, Netlink;
        pub fn writev(&mut self, iovs: &[IoVec], offset: Option<libc::off_t>, flags: libc::c_int,
                      mem: &mut MemoryManager, cb_queue: &mut CallbackQueue) -> Result<libc::ssize_t, SyscallError>
    );
}
298
299impl SocketRefMut<'_> {
301 pub fn getpeername(&self) -> Result<Option<SockaddrStorage>, Errno> {
302 match self {
303 Self::Unix(socket) => socket.getpeername().map(|opt| opt.map(Into::into)),
304 Self::Inet(socket) => socket.getpeername(),
305 Self::Netlink(socket) => socket.getpeername().map(|opt| opt.map(Into::into)),
306 }
307 }
308
309 pub fn getsockname(&self) -> Result<Option<SockaddrStorage>, Errno> {
310 match self {
311 Self::Unix(socket) => socket.getsockname().map(|opt| opt.map(Into::into)),
312 Self::Inet(socket) => socket.getsockname(),
313 Self::Netlink(socket) => socket.getsockname().map(|opt| opt.map(Into::into)),
314 }
315 }
316
317 enum_passthrough!(self, (), Unix, Inet, Netlink;
318 pub fn address_family(&self) -> linux_api::socket::AddressFamily
319 );
320
321 enum_passthrough!(self, (level, optname, optval_ptr, optlen, memory_manager, cb_queue), Unix, Inet, Netlink;
322 pub fn getsockopt(&mut self, level: libc::c_int, optname: libc::c_int, optval_ptr: ForeignPtr<()>,
323 optlen: libc::socklen_t, memory_manager: &mut MemoryManager, cb_queue: &mut CallbackQueue)
324 -> Result<libc::socklen_t, SyscallError>
325 );
326
327 enum_passthrough!(self, (level, optname, optval_ptr, optlen, memory_manager), Unix, Inet, Netlink;
328 pub fn setsockopt(&mut self, level: libc::c_int, optname: libc::c_int, optval_ptr: ForeignPtr<()>,
329 optlen: libc::socklen_t, memory_manager: &MemoryManager)
330 -> Result<(), SyscallError>
331 );
332
333 pub fn accept(
334 &mut self,
335 net_ns: &NetworkNamespace,
336 rng: impl rand::Rng,
337 cb_queue: &mut CallbackQueue,
338 ) -> Result<OpenFile, SyscallError> {
339 match self {
340 Self::Unix(socket) => socket.accept(net_ns, rng, cb_queue),
341 Self::Inet(socket) => socket.accept(net_ns, rng, cb_queue),
342 Self::Netlink(socket) => socket.accept(net_ns, rng, cb_queue),
343 }
344 }
345
346 enum_passthrough!(self, (how, cb_queue), Unix, Inet, Netlink;
347 pub fn shutdown(&mut self, how: Shutdown, cb_queue: &mut CallbackQueue) -> Result<(), SyscallError>
348 );
349}
350
351impl std::fmt::Debug for SocketRef<'_> {
352 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
353 match self {
354 Self::Unix(_) => write!(f, "Unix")?,
355 Self::Inet(_) => write!(f, "Inet")?,
356 Self::Netlink(_) => write!(f, "Netlink")?,
357 }
358
359 write!(
360 f,
361 "(state: {:?}, status: {:?})",
362 self.state(),
363 self.status()
364 )
365 }
366}
367
368impl std::fmt::Debug for SocketRefMut<'_> {
369 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
370 match self {
371 Self::Unix(_) => write!(f, "Unix")?,
372 Self::Inet(_) => write!(f, "Inet")?,
373 Self::Netlink(_) => write!(f, "Netlink")?,
374 }
375
376 write!(
377 f,
378 "(state: {:?}, status: {:?})",
379 self.state(),
380 self.status()
381 )
382 }
383}
384
385pub struct SendmsgArgs<'a> {
387 pub addr: Option<SockaddrStorage>,
389 pub iovs: &'a [IoVec],
391 pub control_ptr: ForeignArrayPtr<u8>,
393 pub flags: libc::c_int,
395}
396
397pub struct RecvmsgArgs<'a> {
399 pub iovs: &'a [IoVec],
401 pub control_ptr: ForeignArrayPtr<u8>,
403 pub flags: libc::c_int,
405}
406
/// Result of a successful [`Socket::recvmsg`].
pub struct RecvmsgReturn {
    /// Value to return from the syscall (presumably the number of bytes received — confirm).
    pub return_val: libc::ssize_t,
    /// Source address of the received message, if one is reported.
    pub addr: Option<SockaddrStorage>,
    /// Flags describing the received message (presumably destined for `msghdr.msg_flags` —
    /// confirm).
    pub msg_flags: libc::c_int,
    /// Length of control data written (presumably in bytes — confirm).
    pub control_len: libc::size_t,
}