shadow_rs/host/descriptor/socket/
abstract_unix_ns.rs1use std::collections::HashMap;
2use std::sync::{Arc, Weak};
3
4use atomic_refcell::AtomicRefCell;
5use rand::seq::IndexedRandom;
6
7use crate::host::descriptor::listener::{StateEventSource, StateListenHandle, StateListenerFilter};
8use crate::host::descriptor::socket::unix::{UnixSocket, UnixSocketType};
9use crate::host::descriptor::{FileSignals, FileState};
10
11struct NamespaceEntry {
12 socket: Weak<AtomicRefCell<UnixSocket>>,
14 _handle: StateListenHandle,
16}
17
18impl NamespaceEntry {
19 pub fn new(socket: Weak<AtomicRefCell<UnixSocket>>, handle: StateListenHandle) -> Self {
20 Self {
21 socket,
22 _handle: handle,
23 }
24 }
25}
26
27pub struct AbstractUnixNamespace {
28 address_map: HashMap<UnixSocketType, HashMap<Vec<u8>, NamespaceEntry>>,
29}
30
31impl AbstractUnixNamespace {
32 pub fn new() -> Self {
33 let mut rv = Self {
34 address_map: HashMap::new(),
36 };
37
38 rv.address_map
40 .insert(UnixSocketType::Stream, HashMap::new());
41 rv.address_map.insert(UnixSocketType::Dgram, HashMap::new());
42 rv.address_map
43 .insert(UnixSocketType::SeqPacket, HashMap::new());
44
45 rv
46 }
47
48 pub fn lookup(
49 &self,
50 sock_type: UnixSocketType,
51 name: &[u8],
52 ) -> Option<Arc<AtomicRefCell<UnixSocket>>> {
53 self.address_map
58 .get(&sock_type)
59 .unwrap()
60 .get(name)
61 .map(|x| x.socket.upgrade().unwrap())
62 }
63
64 pub fn bind(
65 ns_arc: &Arc<AtomicRefCell<Self>>,
66 sock_type: UnixSocketType,
67 mut name: Vec<u8>,
68 socket: &Arc<AtomicRefCell<UnixSocket>>,
69 socket_event_source: &mut StateEventSource,
70 ) -> Result<(), BindError> {
71 name.shrink_to_fit();
73
74 let mut ns = ns_arc.borrow_mut();
75 let name_copy = name.clone();
76
77 let entry = match ns.address_map.get_mut(&sock_type).unwrap().entry(name) {
79 std::collections::hash_map::Entry::Occupied(_) => return Err(BindError::NameInUse),
80 std::collections::hash_map::Entry::Vacant(x) => x,
81 };
82
83 let handle =
85 Self::on_socket_close(Arc::downgrade(ns_arc), socket_event_source, move |ns| {
86 assert!(ns.unbind(sock_type, &name_copy).is_ok());
87 });
88
89 entry.insert(NamespaceEntry::new(Arc::downgrade(socket), handle));
90
91 Ok(())
92 }
93
94 pub fn autobind(
95 ns_arc: &Arc<AtomicRefCell<Self>>,
96 sock_type: UnixSocketType,
97 socket: &Arc<AtomicRefCell<UnixSocket>>,
98 socket_event_source: &mut StateEventSource,
99 mut rng: impl rand::Rng,
100 ) -> Result<Vec<u8>, BindError> {
101 let mut ns = ns_arc.borrow_mut();
102
103 let mut name = None;
105
106 for _ in 0..10 {
108 let random_name: [u8; NAME_LEN] = random_name(&mut rng);
109
110 if !ns
111 .address_map
112 .get(&sock_type)
113 .unwrap()
114 .contains_key(&random_name[..])
115 {
116 name = Some(random_name.to_vec());
117 break;
118 }
119 }
120
121 if name.is_none() {
123 for x in 0..CHARSET.len().pow(NAME_LEN as u32) {
124 let temp_name: [u8; NAME_LEN] = incremental_name(x);
125
126 if !ns
127 .address_map
128 .get(&sock_type)
129 .unwrap()
130 .contains_key(&temp_name[..])
131 {
132 name = Some(temp_name.to_vec());
133 break;
134 }
135 }
136 }
137
138 let name = match name {
139 Some(x) => x,
140 None => return Err(BindError::NoNamesAvailable),
142 };
143
144 let name_copy = name.clone();
145
146 let handle =
148 Self::on_socket_close(Arc::downgrade(ns_arc), socket_event_source, move |ns| {
149 assert!(ns.unbind(sock_type, &name_copy).is_ok());
150 });
151
152 if let std::collections::hash_map::Entry::Vacant(entry) = ns
153 .address_map
154 .get_mut(&sock_type)
155 .unwrap()
156 .entry(name.clone())
157 {
158 entry.insert(NamespaceEntry::new(Arc::downgrade(socket), handle));
159 } else {
160 unreachable!();
161 }
162
163 Ok(name)
164 }
165
166 pub fn unbind(&mut self, sock_type: UnixSocketType, name: &Vec<u8>) -> Result<(), BindError> {
167 if self
170 .address_map
171 .get_mut(&sock_type)
172 .unwrap()
173 .remove(name)
174 .is_none()
175 {
176 return Err(BindError::NameNotFound);
178 }
179
180 Ok(())
181 }
182
183 fn on_socket_close(
185 ns: Weak<AtomicRefCell<Self>>,
186 event_source: &mut StateEventSource,
187 f: impl Fn(&mut Self) + Send + Sync + 'static,
188 ) -> StateListenHandle {
189 event_source.add_listener(
190 FileState::CLOSED,
191 FileSignals::empty(),
192 StateListenerFilter::OffToOn,
193 move |state, _changed, _signals, _cb_queue| {
194 assert!(state.contains(FileState::CLOSED));
195 if let Some(ns) = ns.upgrade() {
196 f(&mut ns.borrow_mut());
197 }
198 },
199 )
200 }
201}
202
203impl Default for AbstractUnixNamespace {
204 fn default() -> Self {
205 Self::new()
206 }
207}
208
/// Errors returned when binding or unbinding names in an abstract unix namespace.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BindError {
    /// The requested name is already bound for this socket type.
    NameInUse,
    /// Autobinding could not find any unused name.
    NoNamesAvailable,
    /// The name to unbind is not bound for this socket type.
    NameNotFound,
}

impl std::error::Error for BindError {}

impl std::fmt::Display for BindError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let msg = match self {
            Self::NameInUse => "Name is already in use",
            Self::NoNamesAvailable => "Names in the ephemeral name range are all in use",
            Self::NameNotFound => "Name was not found in the address map",
        };
        f.write_str(msg)
    }
}
232
/// Length, in characters, of names generated by `random_name` / `incremental_name` during
/// autobind.
const NAME_LEN: usize = 5;
/// Alphabet that generated names draw their characters from.
const CHARSET: &[u8] = b"abcdef0123456789";
237
238fn random_name<const L: usize>(mut rng: impl rand::Rng) -> [u8; L] {
240 let mut name = [0u8; L];
241
242 for c in &mut name {
244 *c = *CHARSET.choose(&mut rng).unwrap();
245 }
246
247 name
248}
249
250fn incremental_name<const L: usize>(mut index: usize) -> [u8; L] {
254 const CHARSET_LEN: usize = CHARSET.len();
255
256 assert!(index < CHARSET_LEN.pow(L as u32));
258
259 let mut name = [0u8; L];
260
261 for x in 0..L {
263 let charset_index = index % CHARSET_LEN;
265 index /= CHARSET_LEN;
266
267 name[L - x - 1] = CHARSET[charset_index];
269 }
270
271 name
272}
273
#[cfg(test)]
mod tests {
    use rand_core::SeedableRng;
    use rand_xoshiro::Xoshiro256PlusPlus;

    use super::*;

    #[test]
    fn test_random_name() {
        let mut rng = Xoshiro256PlusPlus::seed_from_u64(0);

        let first: [u8; 5] = random_name(&mut rng);
        let second: [u8; 5] = random_name(&mut rng);

        // Every character must come from the charset, and consecutive draws should differ.
        let in_charset = |name: &[u8; 5]| name.iter().all(|c| CHARSET.contains(c));
        assert!(in_charset(&first));
        assert!(in_charset(&second));
        assert_ne!(first, second);
    }

    #[test]
    fn test_incremental_name() {
        let base = CHARSET.len();

        // Names count upward like base-`CHARSET.len()` numbers, least-significant digit last.
        assert_eq!(incremental_name::<5>(0), *b"aaaaa");
        assert_eq!(incremental_name::<5>(1), *b"aaaab");
        assert_eq!(incremental_name::<5>(base), *b"aaaba");
        assert_eq!(incremental_name::<5>(base + 1), *b"aaabb");
        assert_eq!(incremental_name::<5>(base.pow(5) - 1), *b"99999");
    }

    #[test]
    #[should_panic]
    fn test_incremental_name_panic() {
        let out_of_range = CHARSET.len().pow(5);
        incremental_name::<5>(out_of_range);
    }
}