use std::collections::{BTreeMap, BTreeSet};
2
3use log::*;
4use shadow_shim_helper_rs::explicit_drop::ExplicitDrop;
5use shadow_shim_helper_rs::syscall_types::SyscallReg;
6
7use crate::host::descriptor::Descriptor;
8use crate::host::host::Host;
9use crate::utility::ObjectCounter;
10use crate::utility::callback_queue::CallbackQueue;
11
/// The largest allowed file-descriptor value: descriptors must fit in a
/// non-negative `i32` (see the `From<DescriptorHandle> for i32` conversion).
pub const FD_MAX: u32 = i32::MAX as u32;

/// A table mapping descriptor handles (fd numbers) to [`Descriptor`]s.
#[derive(Clone)]
pub struct DescriptorTable {
    // All registered descriptors, keyed by handle. BTreeMap keeps them in
    // handle order for `iter`/`iter_mut`.
    descriptors: BTreeMap<DescriptorHandle, Descriptor>,

    // Indices that were handed out at some point and later freed; candidates
    // for reuse by `add`.
    available_indices: BTreeSet<u32>,

    // Low watermark for sequential allocation: indices >= `next_index` have
    // never been handed out by `add` (though `set` may have placed entries
    // there directly).
    next_index: u32,

    // Instance bookkeeping — presumably counts live tables for diagnostics;
    // see `ObjectCounter`.
    _counter: ObjectCounter,
}
30
impl DescriptorTable {
    /// Creates an empty descriptor table.
    pub fn new() -> Self {
        DescriptorTable {
            descriptors: Default::default(),
            available_indices: BTreeSet::new(),
            next_index: 0,
            _counter: ObjectCounter::new("DescriptorTable"),
        }
    }

    /// Inserts `descriptor` at the lowest free index that is `>= min_index`
    /// and returns its handle, or hands the descriptor back as `Err` when no
    /// index `<= FD_MAX` is available.
    fn add(
        &mut self,
        descriptor: Descriptor,
        min_index: DescriptorHandle,
    ) -> Result<DescriptorHandle, Descriptor> {
        // Prefer recycling a previously-freed index at or above `min_index`.
        let idx = if let Some(idx) = self.available_indices.range(min_index.val()..).next() {
            let idx = *idx;
            trace!("Reusing available index {idx}");
            self.available_indices.remove(&idx);
            idx
        } else {
            // Nothing to recycle; allocate from the sequential tail.
            let mut idx = std::cmp::max(self.next_index, min_index.val());

            if idx > FD_MAX {
                // Out of fd space.
                return Err(descriptor);
            }

            // Only advance `next_index` afterwards if we actually started
            // from it; starting from `min_index` must not skip over lower,
            // still-unused indices.
            let should_update_next_index = idx == self.next_index;

            // `set` can place entries at or above `next_index`, so probe
            // forward past any such collisions.
            while self
                .descriptors
                .contains_key(&DescriptorHandle::new(idx).unwrap())
            {
                trace!("Skipping past in-use index {idx}");

                if idx >= FD_MAX {
                    // The probe would exceed `FD_MAX`; give up.
                    return Err(descriptor);
                }

                idx += 1;
            }

            if should_update_next_index {
                self.next_index = idx + 1;
            }

            trace!("Using index {idx}");
            idx
        };

        // Both branches above guarantee `idx <= FD_MAX`, so this can't fail.
        let idx = DescriptorHandle::new(idx).unwrap();

        let prev = self.descriptors.insert(idx, descriptor);
        assert!(prev.is_none(), "Already a descriptor at {idx}");

        Ok(idx)
    }

    /// Shrinks the bookkeeping after removals: while the largest available
    /// index sits directly below `next_index`, fold it back into the
    /// never-used tail by decrementing `next_index`.
    fn trim_tail(&mut self) {
        while let Some(last_in_available) = self.available_indices.iter().next_back().copied() {
            if (last_in_available + 1) == self.next_index {
                self.next_index -= 1;
                self.available_indices.remove(&last_in_available);
            } else {
                break;
            }
        }
    }

    /// Returns the descriptor at `idx`, if any.
    pub fn get(&self, idx: DescriptorHandle) -> Option<&Descriptor> {
        self.descriptors.get(&idx)
    }

    /// Returns a mutable reference to the descriptor at `idx`, if any.
    pub fn get_mut(&mut self, idx: DescriptorHandle) -> Option<&mut Descriptor> {
        self.descriptors.get_mut(&idx)
    }

    /// Inserts `descriptor` at exactly `index`, returning whatever descriptor
    /// previously occupied that index.
    #[must_use]
    fn set(&mut self, index: DescriptorHandle, descriptor: Descriptor) -> Option<Descriptor> {
        // The index is in use now, so it must no longer be offered by `add`.
        self.available_indices.remove(&index.val());

        let prev = self.descriptors.insert(index, descriptor);

        if prev.is_some() {
            trace!("Overwriting index {index}");
        } else {
            trace!("Setting to unused index {index}");
        }

        prev
    }

    /// Registers `desc` at the lowest free index and returns its handle, or
    /// `Err(desc)` when the table is full.
    pub fn register_descriptor(
        &mut self,
        desc: Descriptor,
    ) -> Result<DescriptorHandle, Descriptor> {
        // `0 <= FD_MAX`, so this const handle always exists.
        const ZERO: DescriptorHandle = match DescriptorHandle::new(0) {
            Some(x) => x,
            None => unreachable!(),
        };
        self.add(desc, ZERO)
    }

    /// Like [`Self::register_descriptor`], but only considers indices
    /// `>= min_fd`.
    pub fn register_descriptor_with_min_fd(
        &mut self,
        desc: Descriptor,
        min_fd: DescriptorHandle,
    ) -> Result<DescriptorHandle, Descriptor> {
        self.add(desc, min_fd)
    }

    /// Registers `desc` at exactly `new_fd`, returning the descriptor that
    /// was previously registered there, if any.
    #[must_use]
    pub fn register_descriptor_with_fd(
        &mut self,
        desc: Descriptor,
        new_fd: DescriptorHandle,
    ) -> Option<Descriptor> {
        self.set(new_fd, desc)
    }

    /// Removes and returns the descriptor at `fd`, marking the index free
    /// for reuse.
    ///
    /// NOTE(review): `fd` is inserted into `available_indices` even when no
    /// descriptor was registered at it — presumably callers only deregister
    /// live fds; verify against call sites.
    #[must_use]
    pub fn deregister_descriptor(&mut self, fd: DescriptorHandle) -> Option<Descriptor> {
        let maybe_descriptor = self.descriptors.remove(&fd);
        self.available_indices.insert(fd.val());
        self.trim_tail();
        maybe_descriptor
    }

    /// Removes every descriptor, resetting the table to its initial state,
    /// and yields the removed descriptors.
    pub fn remove_all(&mut self) -> impl Iterator<Item = Descriptor> {
        // `take` swaps in a fresh default table, resetting all bookkeeping
        // (`available_indices`, `next_index`) in one step.
        let old_self = std::mem::take(self);
        old_self.descriptors.into_values()
    }

    /// Removes and yields every descriptor whose handle falls within `range`.
    pub fn remove_range(
        &mut self,
        range: impl std::ops::RangeBounds<DescriptorHandle>,
    ) -> impl Iterator<Item = Descriptor> {
        // Collect matching fds first: we can't mutate while borrowing `self`
        // through `iter()`.
        let fds: Vec<_> = self
            .iter()
            .filter_map(|(fd, _)| range.contains(fd).then_some(*fd))
            .collect();

        let mut descriptors = Vec::with_capacity(fds.len());
        for fd in fds {
            // Every collected fd was just observed in the table, so the
            // unwrap can't fail.
            descriptors.push(self.deregister_descriptor(fd).unwrap());
        }

        descriptors.into_iter()
    }

    /// Iterates over all `(handle, descriptor)` pairs in handle order.
    pub fn iter(&self) -> impl Iterator<Item = (&DescriptorHandle, &Descriptor)> {
        self.descriptors.iter()
    }

    /// Like [`Self::iter`], but with mutable access to the descriptors.
    pub fn iter_mut(&mut self) -> impl Iterator<Item = (&DescriptorHandle, &mut Descriptor)> {
        self.descriptors.iter_mut()
    }
}
231
232impl Default for DescriptorTable {
233 fn default() -> Self {
234 Self::new()
235 }
236}
237
238impl ExplicitDrop for DescriptorTable {
239 type ExplicitDropParam = Host;
240 type ExplicitDropResult = ();
241
242 fn explicit_drop(mut self, host: &Host) {
243 let descriptors = self.remove_all();
250 CallbackQueue::queue_and_run_with_legacy(|cb_queue| {
251 for desc in descriptors {
252 desc.close(host, cb_queue);
253 }
254 });
255 }
256}
257
/// A descriptor handle (fd number). Constructed via [`DescriptorHandle::new`],
/// which rejects values above `FD_MAX`.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct DescriptorHandle(u32);
261
262impl DescriptorHandle {
263 pub const fn new(fd: u32) -> Option<Self> {
265 if fd > FD_MAX {
266 return None;
267 }
268
269 Some(DescriptorHandle(fd))
270 }
271
272 pub fn val(&self) -> u32 {
273 self.0
274 }
275}
276
277impl std::fmt::Display for DescriptorHandle {
278 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
279 self.0.fmt(f)
280 }
281}
282
283impl From<DescriptorHandle> for u32 {
284 fn from(x: DescriptorHandle) -> u32 {
285 x.0
286 }
287}
288
289impl From<DescriptorHandle> for u64 {
290 fn from(x: DescriptorHandle) -> u64 {
291 x.0.into()
292 }
293}
294
295impl From<DescriptorHandle> for i32 {
296 fn from(x: DescriptorHandle) -> i32 {
297 const { assert!(FD_MAX <= i32::MAX as u32) };
298 x.0.try_into().unwrap()
300 }
301}
302
303impl From<DescriptorHandle> for i64 {
304 fn from(x: DescriptorHandle) -> i64 {
305 x.0.into()
306 }
307}
308
309impl From<DescriptorHandle> for SyscallReg {
310 fn from(x: DescriptorHandle) -> SyscallReg {
311 x.0.into()
312 }
313}
314
315impl TryFrom<u32> for DescriptorHandle {
316 type Error = DescriptorHandleError;
317 fn try_from(x: u32) -> Result<Self, Self::Error> {
318 DescriptorHandle::new(x).ok_or(DescriptorHandleError())
319 }
320}
321
322impl TryFrom<u64> for DescriptorHandle {
323 type Error = <DescriptorHandle as TryFrom<u32>>::Error;
325 fn try_from(x: u64) -> Result<Self, Self::Error> {
326 u32::try_from(x)
327 .or(Err(DescriptorHandleError()))?
328 .try_into()
329 }
330}
331
332impl TryFrom<i32> for DescriptorHandle {
333 type Error = DescriptorHandleError;
334 fn try_from(x: i32) -> Result<Self, Self::Error> {
335 x.try_into()
336 .ok()
337 .and_then(DescriptorHandle::new)
338 .ok_or(DescriptorHandleError())
339 }
340}
341
342impl TryFrom<i64> for DescriptorHandle {
343 type Error = <DescriptorHandle as TryFrom<i32>>::Error;
345 fn try_from(x: i64) -> Result<Self, Self::Error> {
346 i32::try_from(x)
347 .or(Err(DescriptorHandleError()))?
348 .try_into()
349 }
350}
351
/// Error returned when converting an out-of-range integer into a
/// [`DescriptorHandle`].
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct DescriptorHandleError();
355
356impl std::fmt::Display for DescriptorHandleError {
357 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
358 write!(f, "Not a valid descriptor handle")
359 }
360}
361
// Marker impl: there is no underlying `source()` error, so the trait's
// default methods suffice.
impl std::error::Error for DescriptorHandleError {}