// shadow_rs/host/descriptor/socket/abstract_unix_ns.rs

1use std::collections::HashMap;
2use std::sync::{Arc, Weak};
3
4use atomic_refcell::AtomicRefCell;
5use rand::seq::IndexedRandom;
6
7use crate::host::descriptor::listener::{StateEventSource, StateListenHandle, StateListenerFilter};
8use crate::host::descriptor::socket::unix::{UnixSocket, UnixSocketType};
9use crate::host::descriptor::{FileSignals, FileState};
10
11struct NamespaceEntry {
12    /// The bound socket.
13    socket: Weak<AtomicRefCell<UnixSocket>>,
14    /// The event listener handle, which removes the listener when dropped.
15    _handle: StateListenHandle,
16}
17
18impl NamespaceEntry {
19    pub fn new(socket: Weak<AtomicRefCell<UnixSocket>>, handle: StateListenHandle) -> Self {
20        Self {
21            socket,
22            _handle: handle,
23        }
24    }
25}
26
27pub struct AbstractUnixNamespace {
28    address_map: HashMap<UnixSocketType, HashMap<Vec<u8>, NamespaceEntry>>,
29}
30
31impl AbstractUnixNamespace {
32    pub fn new() -> Self {
33        let mut rv = Self {
34            // initializes an empty hash map for each unix socket type
35            address_map: HashMap::new(),
36        };
37
38        // the namespace code will assume that there is an entry for each socket type
39        rv.address_map
40            .insert(UnixSocketType::Stream, HashMap::new());
41        rv.address_map.insert(UnixSocketType::Dgram, HashMap::new());
42        rv.address_map
43            .insert(UnixSocketType::SeqPacket, HashMap::new());
44
45        rv
46    }
47
48    pub fn lookup(
49        &self,
50        sock_type: UnixSocketType,
51        name: &[u8],
52    ) -> Option<Arc<AtomicRefCell<UnixSocket>>> {
53        // the unwrap() will panic if the socket was dropped without being closed, but this should
54        // only be possible at the end of the simulation and there wouldn't be any reason to call
55        // lookup() at that time, so a panic here would most likely indicate an issue somewhere else
56        // in shadow
57        self.address_map
58            .get(&sock_type)
59            .unwrap()
60            .get(name)
61            .map(|x| x.socket.upgrade().unwrap())
62    }
63
64    pub fn bind(
65        ns_arc: &Arc<AtomicRefCell<Self>>,
66        sock_type: UnixSocketType,
67        mut name: Vec<u8>,
68        socket: &Arc<AtomicRefCell<UnixSocket>>,
69        socket_event_source: &mut StateEventSource,
70    ) -> Result<(), BindError> {
71        // make sure we aren't wasting memory since we don't mutate the name
72        name.shrink_to_fit();
73
74        let mut ns = ns_arc.borrow_mut();
75        let name_copy = name.clone();
76
77        // look up the name in the address map
78        let entry = match ns.address_map.get_mut(&sock_type).unwrap().entry(name) {
79            std::collections::hash_map::Entry::Occupied(_) => return Err(BindError::NameInUse),
80            std::collections::hash_map::Entry::Vacant(x) => x,
81        };
82
83        // when the socket closes, remove this entry from the namespace
84        let handle =
85            Self::on_socket_close(Arc::downgrade(ns_arc), socket_event_source, move |ns| {
86                assert!(ns.unbind(sock_type, &name_copy).is_ok());
87            });
88
89        entry.insert(NamespaceEntry::new(Arc::downgrade(socket), handle));
90
91        Ok(())
92    }
93
94    pub fn autobind(
95        ns_arc: &Arc<AtomicRefCell<Self>>,
96        sock_type: UnixSocketType,
97        socket: &Arc<AtomicRefCell<UnixSocket>>,
98        socket_event_source: &mut StateEventSource,
99        mut rng: impl rand::Rng,
100    ) -> Result<Vec<u8>, BindError> {
101        let mut ns = ns_arc.borrow_mut();
102
103        // the unused name that we will bind the socket to
104        let mut name = None;
105
106        // try 10 random names
107        for _ in 0..10 {
108            let random_name: [u8; NAME_LEN] = random_name(&mut rng);
109
110            if !ns
111                .address_map
112                .get(&sock_type)
113                .unwrap()
114                .contains_key(&random_name[..])
115            {
116                name = Some(random_name.to_vec());
117                break;
118            }
119        }
120
121        // if unsuccessful, try a linear search through all valid names
122        if name.is_none() {
123            for x in 0..CHARSET.len().pow(NAME_LEN as u32) {
124                let temp_name: [u8; NAME_LEN] = incremental_name(x);
125
126                if !ns
127                    .address_map
128                    .get(&sock_type)
129                    .unwrap()
130                    .contains_key(&temp_name[..])
131                {
132                    name = Some(temp_name.to_vec());
133                    break;
134                }
135            }
136        }
137
138        let name = match name {
139            Some(x) => x,
140            // every valid name has been taken
141            None => return Err(BindError::NoNamesAvailable),
142        };
143
144        let name_copy = name.clone();
145
146        // when the socket closes, remove this entry from the namespace
147        let handle =
148            Self::on_socket_close(Arc::downgrade(ns_arc), socket_event_source, move |ns| {
149                assert!(ns.unbind(sock_type, &name_copy).is_ok());
150            });
151
152        if let std::collections::hash_map::Entry::Vacant(entry) = ns
153            .address_map
154            .get_mut(&sock_type)
155            .unwrap()
156            .entry(name.clone())
157        {
158            entry.insert(NamespaceEntry::new(Arc::downgrade(socket), handle));
159        } else {
160            unreachable!();
161        }
162
163        Ok(name)
164    }
165
166    pub fn unbind(&mut self, sock_type: UnixSocketType, name: &Vec<u8>) -> Result<(), BindError> {
167        // remove the namespace entry which includes the handle, so the event listener will
168        // automatically be removed from the socket
169        if self
170            .address_map
171            .get_mut(&sock_type)
172            .unwrap()
173            .remove(name)
174            .is_none()
175        {
176            // didn't exist in the address map
177            return Err(BindError::NameNotFound);
178        }
179
180        Ok(())
181    }
182
183    /// Adds a listener to the socket which runs the callback `f` when the socket is closed.
184    fn on_socket_close(
185        ns: Weak<AtomicRefCell<Self>>,
186        event_source: &mut StateEventSource,
187        f: impl Fn(&mut Self) + Send + Sync + 'static,
188    ) -> StateListenHandle {
189        event_source.add_listener(
190            FileState::CLOSED,
191            FileSignals::empty(),
192            StateListenerFilter::OffToOn,
193            move |state, _changed, _signals, _cb_queue| {
194                assert!(state.contains(FileState::CLOSED));
195                if let Some(ns) = ns.upgrade() {
196                    f(&mut ns.borrow_mut());
197                }
198            },
199        )
200    }
201}
202
203impl Default for AbstractUnixNamespace {
204    fn default() -> Self {
205        Self::new()
206    }
207}
208
/// Errors from binding or unbinding a name in an abstract unix socket namespace.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BindError {
    /// The name is already in use.
    NameInUse,
    /// Names in the ephemeral name range are all in use.
    NoNamesAvailable,
    /// The name was not found in the address map.
    NameNotFound,
}

impl std::error::Error for BindError {}

impl std::fmt::Display for BindError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let msg = match self {
            Self::NameInUse => "Name is already in use",
            Self::NoNamesAvailable => "Names in the ephemeral name range are all in use",
            Self::NameNotFound => "Name was not found in the address map",
        };
        f.write_str(msg)
    }
}
232
/// The number of characters in an auto-generated name.
const NAME_LEN: usize = 5;

/// The characters that may appear in auto-generated names. See subsection "Autobind feature" in
/// unix(7), which limits autobound names to the character set [0-9a-f].
const CHARSET: &[u8] = b"abcdef0123456789";
237
238/// Choose a random name of length `L`.
239fn random_name<const L: usize>(mut rng: impl rand::Rng) -> [u8; L] {
240    let mut name = [0u8; L];
241
242    // set each character of the name
243    for c in &mut name {
244        *c = *CHARSET.choose(&mut rng).unwrap();
245    }
246
247    name
248}
249
250/// Get a name in the set of all valid names. This is essentially the n'th element in the cartesian
251/// power of set `CHARSET`. This would be better implemented as a generator when generators become
252/// stable.
253fn incremental_name<const L: usize>(mut index: usize) -> [u8; L] {
254    const CHARSET_LEN: usize = CHARSET.len();
255
256    // there are a limited number of valid names
257    assert!(index < CHARSET_LEN.pow(L as u32));
258
259    let mut name = [0u8; L];
260
261    // set each character of the name
262    for x in 0..L {
263        // take the base-10 index and convert it to base-CHARSET_LEN digits
264        let charset_index = index % CHARSET_LEN;
265        index /= CHARSET_LEN;
266
267        // set the name in reverse order
268        name[L - x - 1] = CHARSET[charset_index];
269    }
270
271    name
272}
273
#[cfg(test)]
mod tests {
    use rand_core::SeedableRng;
    use rand_xoshiro::Xoshiro256PlusPlus;

    use super::*;

    /// Whether every byte of `name` is an allowed auto-generated name character.
    fn in_charset(name: &[u8]) -> bool {
        name.iter().all(|byte| CHARSET.contains(byte))
    }

    #[test]
    fn test_random_name() {
        let mut rng = Xoshiro256PlusPlus::seed_from_u64(0);

        let first: [u8; 5] = random_name(&mut rng);
        let second: [u8; 5] = random_name(&mut rng);

        assert!(in_charset(&first));
        assert!(in_charset(&second));
        // with this seed, two successive draws produce different names
        assert_ne!(first, second);
    }

    #[test]
    fn test_incremental_name() {
        let base = CHARSET.len();

        assert_eq!(incremental_name::<5>(0), *b"aaaaa");
        assert_eq!(incremental_name::<5>(1), *b"aaaab");
        assert_eq!(incremental_name::<5>(base), *b"aaaba");
        assert_eq!(incremental_name::<5>(base + 1), *b"aaabb");
        assert_eq!(incremental_name::<5>(base.pow(5) - 1), *b"99999");
    }

    #[test]
    #[should_panic]
    fn test_incremental_name_panic() {
        incremental_name::<5>(CHARSET.len().pow(5));
    }
}