shadow_rs/host/descriptor/descriptor_table.rs

use std::collections::{BTreeSet, HashMap};

use log::*;
use shadow_shim_helper_rs::explicit_drop::ExplicitDrop;
use shadow_shim_helper_rs::syscall_types::SyscallReg;

use crate::host::descriptor::Descriptor;
use crate::host::host::Host;
use crate::utility::ObjectCounter;
use crate::utility::callback_queue::CallbackQueue;

/// POSIX requires fds to be representable as a `libc::c_int`, so we can't allow any fds larger
/// than this.
pub const FD_MAX: u32 = i32::MAX as u32;

/// Map of file handles to file descriptors. Typically owned by a
/// [`Thread`][crate::host::thread::Thread].
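///
/// Index bookkeeping sketch (hypothetical state): after registering descriptors at fds 0
/// through 5 and then deregistering fds 2 and 4, `available_indices` is `{2, 4}` and
/// `next_index` is 6; the next [`register_descriptor`](Self::register_descriptor) call reuses
/// index 2.
///
/// A rough usage sketch (`make_descriptor()` is a hypothetical helper; constructing a real
/// [`Descriptor`] requires an open file):
///
/// ```ignore
/// let mut table = DescriptorTable::new();
/// let fd = table.register_descriptor(make_descriptor()).unwrap();
/// assert!(table.get(fd).is_some());
/// let desc = table.deregister_descriptor(fd).unwrap();
/// ```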
#[derive(Clone)]
pub struct DescriptorTable {
    descriptors: HashMap<DescriptorHandle, Descriptor>,

    // Indices less than `next_index` that are known to be available.
    available_indices: BTreeSet<u32>,

    // Lowest index not in `available_indices` that *might* be available. We still need to verify
    // availability in `descriptors`, though.
    next_index: u32,

    _counter: ObjectCounter,
}

impl DescriptorTable {
    pub fn new() -> Self {
        DescriptorTable {
            descriptors: HashMap::new(),
            available_indices: BTreeSet::new(),
            next_index: 0,
            _counter: ObjectCounter::new("DescriptorTable"),
        }
    }

    /// Add the descriptor at an unused index at or above `min_index`, and return the index. If
    /// the descriptor could not be added, the descriptor is returned in the `Err`.
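    ///
    /// For example (hypothetical state): with `available_indices` = `{3}` and `next_index` = 7,
    /// adding with `min_index` 0 returns index 3, while adding with `min_index` 5 returns
    /// index 7 and advances `next_index` to 8.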
    fn add(
        &mut self,
        descriptor: Descriptor,
        min_index: DescriptorHandle,
    ) -> Result<DescriptorHandle, Descriptor> {
        let idx = if let Some(idx) = self.available_indices.range(min_index.val()..).next() {
            // Copy the value to un-borrow from `available_indices`.
            let idx = *idx;
            // Take the index from `available_indices`.
            trace!("Reusing available index {}", idx);
            self.available_indices.remove(&idx);
            idx
        } else {
            // Start our search at either the next likely available index or the minimum index,
            // whichever is larger.
            let mut idx = std::cmp::max(self.next_index, min_index.val());

            // Check if this index is out of range.
            if idx > FD_MAX {
                return Err(descriptor);
            }

            // Only update `next_index` if we started at it, otherwise there may be other
            // available indices lower than `idx`.
            let should_update_next_index = idx == self.next_index;

            // Skip past any indices that are in use. This can happen after
            // calling `set` with a value greater than `next_index`.
            while self
                .descriptors
                .contains_key(&DescriptorHandle::new(idx).unwrap())
            {
                trace!("Skipping past in-use index {}", idx);

                // Check if the next index is out of range.
                if idx >= FD_MAX {
                    return Err(descriptor);
                }

                // Won't overflow because of the check above.
                idx += 1;
            }

            if should_update_next_index {
                self.next_index = idx + 1;
            }

            // Take the next index.
            trace!("Using index {}", idx);
            idx
        };

        let idx = DescriptorHandle::new(idx).unwrap();

        let prev = self.descriptors.insert(idx, descriptor);
        assert!(prev.is_none(), "Already a descriptor at {}", idx);

        Ok(idx)
    }

    // Call after inserting into `available_indices`, to merge any trailing entries that are
    // contiguous with `next_index` back into it.
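    //
    // For example: with `next_index` = 6 and `available_indices` = `{4, 5}`, trimming leaves
    // `next_index` = 4 and an empty `available_indices`.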
    fn trim_tail(&mut self) {
        while let Some(last_in_available) = self.available_indices.iter().next_back().copied() {
            if (last_in_available + 1) == self.next_index {
                // The last entry in `available_indices` is adjacent to `next_index`.
                // We can merge them, freeing an entry in `available_indices`.
                self.next_index -= 1;
                self.available_indices.remove(&last_in_available);
            } else {
                break;
            }
        }
    }

    /// Get the descriptor at `idx`, if any.
    pub fn get(&self, idx: DescriptorHandle) -> Option<&Descriptor> {
        self.descriptors.get(&idx)
    }

    /// Get a mutable reference to the descriptor at `idx`, if any.
    pub fn get_mut(&mut self, idx: DescriptorHandle) -> Option<&mut Descriptor> {
        self.descriptors.get_mut(&idx)
    }

    /// Insert a descriptor at `index`. If a descriptor is already present at that index, it is
    /// unregistered from that index and returned.
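    ///
    /// This is the kind of primitive a `dup2`-style operation would build on: the new descriptor
    /// takes over `index`, and the caller decides what to do with the displaced one.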
    #[must_use]
    fn set(&mut self, index: DescriptorHandle, descriptor: Descriptor) -> Option<Descriptor> {
        // We ensure the index is no longer in `self.available_indices`. We *don't* ensure
        // `self.next_index` is > `index`, since that would require adding all the indices in
        // between to `self.available_indices`. Leaving them out uses less memory, and `add`'s
        // skip-past-in-use loop is no more expensive than iterating over the extra entries in
        // `self.available_indices` would be.
        self.available_indices.remove(&index.val());

        let prev = self.descriptors.insert(index, descriptor);

        if prev.is_some() {
            trace!("Overwriting index {}", index);
        } else {
            trace!("Setting to unused index {}", index);
        }

        prev
    }

    /// Register a descriptor and return its fd handle. Equivalent to
    /// [`register_descriptor_with_min_fd(desc, 0)`][Self::register_descriptor_with_min_fd]. If the
    /// descriptor could not be added, the descriptor is returned in the `Err`.
    pub fn register_descriptor(
        &mut self,
        desc: Descriptor,
    ) -> Result<DescriptorHandle, Descriptor> {
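        // `DescriptorHandle::new(0)` always succeeds (0 <= FD_MAX); evaluating the `match` in a
        // `const` makes that a compile-time guarantee rather than a runtime unwrap.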
        const ZERO: DescriptorHandle = match DescriptorHandle::new(0) {
            Some(x) => x,
            None => unreachable!(),
        };
        self.add(desc, ZERO)
    }

    /// Register a descriptor and return its fd handle. If the descriptor could not be added, the
    /// descriptor is returned in the `Err`.
    pub fn register_descriptor_with_min_fd(
        &mut self,
        desc: Descriptor,
        min_fd: DescriptorHandle,
    ) -> Result<DescriptorHandle, Descriptor> {
        self.add(desc, min_fd)
    }

    /// Register a descriptor with a given fd handle and return the descriptor that it replaced.
    #[must_use]
    pub fn register_descriptor_with_fd(
        &mut self,
        desc: Descriptor,
        new_fd: DescriptorHandle,
    ) -> Option<Descriptor> {
        self.set(new_fd, desc)
    }

    /// Deregister the descriptor with the given fd handle and return it.
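    ///
    /// The freed index becomes available for reuse by later calls to
    /// [`register_descriptor`](Self::register_descriptor).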
    #[must_use]
    pub fn deregister_descriptor(&mut self, fd: DescriptorHandle) -> Option<Descriptor> {
        let maybe_descriptor = self.descriptors.remove(&fd);
        self.available_indices.insert(fd.val());
        self.trim_tail();
        maybe_descriptor
    }

    /// Remove and return all descriptors.
    pub fn remove_all(&mut self) -> impl Iterator<Item = Descriptor> {
        // Reset the descriptor table.
        let old_self = std::mem::replace(self, Self::new());
        // Return the old descriptors.
        old_self.descriptors.into_values()
    }

    /// Remove and return all descriptors in the range. If you want to remove all descriptors, you
    /// should use [`remove_all`](Self::remove_all).
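    ///
    /// This is the shape of operation a `close_range`-style syscall would need; for example
    /// (hypothetical handles), `remove_range(low..)` removes and returns every descriptor
    /// registered at fd `low` or above.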
    pub fn remove_range(
        &mut self,
        range: impl std::ops::RangeBounds<DescriptorHandle>,
    ) -> impl Iterator<Item = Descriptor> {
        // This code is not very efficient, but it shouldn't be called often, so it should be fine
        // for now. If we wanted something more efficient, we'd need to redesign the descriptor
        // table to not use a hash map.

        let fds: Vec<_> = self
            .iter()
            .filter_map(|(fd, _)| range.contains(fd).then_some(*fd))
            .collect();

        let mut descriptors = Vec::with_capacity(fds.len());
        for fd in fds {
            descriptors.push(self.deregister_descriptor(fd).unwrap());
        }

        descriptors.into_iter()
    }

    /// Iterate over all `(handle, descriptor)` pairs in the table.
    pub fn iter(&self) -> impl Iterator<Item = (&DescriptorHandle, &Descriptor)> {
        self.descriptors.iter()
    }

    /// Iterate over all `(handle, descriptor)` pairs in the table, with mutable access to the
    /// descriptors.
    pub fn iter_mut(&mut self) -> impl Iterator<Item = (&DescriptorHandle, &mut Descriptor)> {
        self.descriptors.iter_mut()
    }
}

impl Default for DescriptorTable {
    fn default() -> Self {
        Self::new()
    }
}

impl ExplicitDrop for DescriptorTable {
    type ExplicitDropParam = Host;
    type ExplicitDropResult = ();

    fn explicit_drop(mut self, host: &Host) {
        // Drop all descriptors using a callback queue.
        //
        // Doing this explicitly, instead of letting `DescriptorTable`'s `Drop` implementation
        // implicitly close each descriptor individually, is a performance optimization: all
        // descriptors are closed before any of their callbacks run.
        let descriptors = self.remove_all();
        CallbackQueue::queue_and_run_with_legacy(|cb_queue| {
            for desc in descriptors {
                desc.close(host, cb_queue);
            }
        });
    }
}

/// A handle for a file descriptor.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct DescriptorHandle(u32);

impl DescriptorHandle {
    /// Returns `Some` if `fd` is at most [`FD_MAX`]. Can be used in `const` contexts.
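    ///
    /// A behavior sketch (marked `ignore` only because the `use` paths are omitted):
    ///
    /// ```ignore
    /// assert!(DescriptorHandle::new(0).is_some());
    /// assert!(DescriptorHandle::new(FD_MAX).is_some());
    /// assert!(DescriptorHandle::new(FD_MAX + 1).is_none());
    /// ```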
    pub const fn new(fd: u32) -> Option<Self> {
        if fd > FD_MAX {
            return None;
        }

        Some(DescriptorHandle(fd))
    }

    /// The raw integer value of the handle.
    pub fn val(&self) -> u32 {
        self.0
    }
}

impl std::fmt::Display for DescriptorHandle {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        self.0.fmt(f)
    }
}

impl From<DescriptorHandle> for u32 {
    fn from(x: DescriptorHandle) -> u32 {
        x.0
    }
}

impl From<DescriptorHandle> for u64 {
    fn from(x: DescriptorHandle) -> u64 {
        x.0.into()
    }
}

impl From<DescriptorHandle> for i32 {
    fn from(x: DescriptorHandle) -> i32 {
        const { assert!(FD_MAX <= i32::MAX as u32) };
        // The constructor makes sure this won't panic.
        x.0.try_into().unwrap()
    }
}

impl From<DescriptorHandle> for i64 {
    fn from(x: DescriptorHandle) -> i64 {
        x.0.into()
    }
}

impl From<DescriptorHandle> for SyscallReg {
    fn from(x: DescriptorHandle) -> SyscallReg {
        x.0.into()
    }
}

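// Fallible conversions from raw integer types. Anything out of range (negative, or greater than
// `FD_MAX`) yields a `DescriptorHandleError`; for example, `DescriptorHandle::try_from(-1i32)`
// and `DescriptorHandle::try_from(u64::MAX)` both fail.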
impl TryFrom<u32> for DescriptorHandle {
    type Error = DescriptorHandleError;
    fn try_from(x: u32) -> Result<Self, Self::Error> {
        DescriptorHandle::new(x).ok_or(DescriptorHandleError())
    }
}

impl TryFrom<u64> for DescriptorHandle {
    // Use the same error type as the conversion from u32.
    type Error = <DescriptorHandle as TryFrom<u32>>::Error;
    fn try_from(x: u64) -> Result<Self, Self::Error> {
        u32::try_from(x)
            .or(Err(DescriptorHandleError()))?
            .try_into()
    }
}

impl TryFrom<i32> for DescriptorHandle {
    type Error = DescriptorHandleError;
    fn try_from(x: i32) -> Result<Self, Self::Error> {
        x.try_into()
            .ok()
            .and_then(DescriptorHandle::new)
            .ok_or(DescriptorHandleError())
    }
}

impl TryFrom<i64> for DescriptorHandle {
    // Use the same error type as the conversion from i32.
    type Error = <DescriptorHandle as TryFrom<i32>>::Error;
    fn try_from(x: i64) -> Result<Self, Self::Error> {
        i32::try_from(x)
            .or(Err(DescriptorHandleError()))?
            .try_into()
    }
}

/// The handle is not valid.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct DescriptorHandleError();

impl std::fmt::Display for DescriptorHandleError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Not a valid descriptor handle")
    }
}

impl std::error::Error for DescriptorHandleError {}