// shadow_shim_helper_rs/rootedcell/rc.rs
extern crate alloc;
2
3use core::{
4    cell::{Cell, UnsafeCell},
5    ptr::NonNull,
6};
7
8use alloc::boxed::Box;
9
10use crate::explicit_drop::ExplicitDrop;
11
12use super::{Root, Tag};
13
/// Heap-allocated state shared by every strong and weak handle to one value.
///
/// The counters are plain (non-atomic) `Cell`s, so nothing in this type
/// synchronizes access by itself.
struct RootedRcInternal<T> {
    /// The managed value: `Some` from construction until it is taken out.
    val: UnsafeCell<Option<T>>,
    /// Number of live strong handles; starts at 1 for the creating handle.
    strong_count: Cell<u32>,
    /// Number of live weak handles; starts at 0.
    weak_count: Cell<u32>,
}

impl<T> RootedRcInternal<T> {
    /// Creates the shared state holding `val`, with one strong handle and no
    /// weak handles recorded.
    pub fn new(val: T) -> Self {
        Self {
            val: UnsafeCell::new(Some(val)),
            strong_count: Cell::new(1),
            weak_count: Cell::new(0),
        }
    }

    /// Records one additional strong handle.
    ///
    /// # Panics
    ///
    /// Panics if the count would exceed `u32::MAX`. Silently wrapping (what an
    /// unchecked `+ 1` does in release builds) would let the count reach zero
    /// while handles still exist.
    pub fn inc_strong(&self) {
        Self::checked_inc(&self.strong_count, "RootedRc strong count overflow")
    }

    /// Records that one strong handle went away.
    ///
    /// # Panics
    ///
    /// Panics if the count is already zero, which indicates unbalanced
    /// bookkeeping.
    pub fn dec_strong(&self) {
        Self::checked_dec(&self.strong_count, "RootedRc strong count underflow")
    }

    /// Records one additional weak handle.
    ///
    /// # Panics
    ///
    /// Panics if the count would exceed `u32::MAX`.
    pub fn inc_weak(&self) {
        Self::checked_inc(&self.weak_count, "RootedRc weak count overflow")
    }

    /// Records that one weak handle went away.
    ///
    /// # Panics
    ///
    /// Panics if the count is already zero, which indicates unbalanced
    /// bookkeeping.
    pub fn dec_weak(&self) {
        Self::checked_dec(&self.weak_count, "RootedRc weak count underflow")
    }

    /// Adds one to `counter`, panicking with `overflow_msg` instead of wrapping.
    fn checked_inc(counter: &Cell<u32>, overflow_msg: &str) {
        let next = counter.get().checked_add(1).expect(overflow_msg);
        counter.set(next);
    }

    /// Subtracts one from `counter`, panicking with `underflow_msg` instead of
    /// wrapping.
    fn checked_dec(counter: &Cell<u32>, underflow_msg: &str) {
        let next = counter.get().checked_sub(1).expect(underflow_msg);
        counter.set(next);
    }
}
45
/// Selects which of the two reference counts an operation should adjust.
enum RefType {
    /// The operation concerns a weak handle (the weak count).
    Weak,
    /// The operation concerns a strong handle (the strong count).
    Strong,
}
50
51struct RootedRcCommon<T> {
53    tag: Tag,
54    internal: Option<NonNull<RootedRcInternal<T>>>,
55}
56
57impl<T> RootedRcCommon<T> {
58    pub fn new(root: &Root, val: T) -> Self {
59        Self {
60            tag: root.tag(),
61            internal: Some(
62                NonNull::new(Box::into_raw(Box::new(RootedRcInternal::new(val)))).unwrap(),
63            ),
64        }
65    }
66
67    pub fn borrow_internal(&self, root: &Root) -> &RootedRcInternal<T> {
70        assert_eq!(
71            root.tag, self.tag,
72            "Tried using root {:?} instead of {:?}",
73            root.tag, self.tag
74        );
75        unsafe { self.internal.unwrap().as_ref() }
80    }
81
82    pub fn safely_drop(mut self, root: &Root, t: RefType) -> Option<T> {
84        let internal: &RootedRcInternal<T> = self.borrow_internal(root);
85        match t {
86            RefType::Weak => internal.dec_weak(),
87            RefType::Strong => internal.dec_strong(),
88        };
89        let strong_count = internal.strong_count.get();
90        let weak_count = internal.weak_count.get();
91
92        let val: Option<T> = if strong_count == 0 {
97            unsafe { internal.val.get().as_mut().unwrap().take() }
100        } else {
101            None
102        };
103
104        let internal: NonNull<RootedRcInternal<T>> = self.internal.take().unwrap();
106
107        if strong_count == 0 && weak_count == 0 {
109            drop(unsafe { Box::from_raw(internal.as_ptr()) });
113        }
114
115        val
116    }
117
118    pub fn clone(&self, root: &Root, t: RefType) -> Self {
119        let internal: &RootedRcInternal<T> = self.borrow_internal(root);
120        match t {
121            RefType::Weak => internal.inc_weak(),
122            RefType::Strong => internal.inc_strong(),
123        };
124        Self {
125            tag: self.tag,
126            internal: self.internal,
127        }
128    }
129}
130
/// Reports whether the current thread is already panicking.
///
/// Variant for debug builds with the `std` feature: defers to
/// `std::thread::panicking`.
#[cfg(all(debug_assertions, feature = "std"))]
fn already_panicking() -> bool {
    std::thread::panicking()
}

/// Reports whether the current thread is already panicking.
///
/// Variant for debug builds without the `std` feature: always answers `false`.
#[cfg(all(debug_assertions, not(feature = "std")))]
fn already_panicking() -> bool {
    false
}
139
140impl<T> Drop for RootedRcCommon<T> {
141    #[inline]
142    fn drop(&mut self) {
143        if self.internal.is_some() {
144            log::error!("Dropped without calling `explicit_drop`");
146
147            #[cfg(debug_assertions)]
158            if !already_panicking() {
159                panic!("Dropped without calling `explicit_drop`");
160            }
161        }
162    }
163}
164
165unsafe impl<T: Sync + Send> Send for RootedRcCommon<T> {}
169unsafe impl<T: Sync + Send> Sync for RootedRcCommon<T> {}
170
171pub struct RootedRc<T> {
186    common: RootedRcCommon<T>,
187}
188
189impl<T> RootedRc<T> {
190    #[inline]
192    pub fn new(root: &Root, val: T) -> Self {
193        Self {
194            common: RootedRcCommon::new(root, val),
195        }
196    }
197
198    #[inline]
203    pub fn downgrade(this: &Self, root: &Root) -> RootedRcWeak<T> {
204        RootedRcWeak {
205            common: this.common.clone(root, RefType::Weak),
206        }
207    }
208
209    #[inline]
215    pub fn clone(&self, root: &Root) -> Self {
216        Self {
217            common: self.common.clone(root, RefType::Strong),
218        }
219    }
220
221    #[inline]
224    pub fn into_inner(this: Self, root: &Root) -> Option<T> {
225        this.common.safely_drop(root, RefType::Strong)
226    }
227
228    pub fn explicit_drop_recursive(
231        self,
232        root: &Root,
233        param: &T::ExplicitDropParam,
234    ) -> Option<T::ExplicitDropResult>
235    where
236        T: ExplicitDrop,
237    {
238        Self::into_inner(self, root).map(|val| val.explicit_drop(param))
239    }
240}
241
242impl<T> ExplicitDrop for RootedRc<T> {
243    type ExplicitDropParam = Root;
244
245    type ExplicitDropResult = ();
246
247    fn explicit_drop(self, root: &Self::ExplicitDropParam) -> Self::ExplicitDropResult {
251        self.common.safely_drop(root, RefType::Strong);
252    }
253}
254
255impl<T> core::ops::Deref for RootedRc<T> {
256    type Target = T;
257
258    #[inline]
259    fn deref(&self) -> &Self::Target {
260        let internal = unsafe { self.common.internal.unwrap().as_ref() };
266
267        let val = unsafe { &*internal.val.get() };
272        val.as_ref().unwrap()
273    }
274}
275
#[cfg(test)]
mod test_rooted_rc {
    use std::{sync::Arc, thread};

    use super::*;

    /// A handle can be created and explicitly released under the same root.
    #[test]
    fn construct_and_drop() {
        let root = Root::new();
        let handle = RootedRc::new(&root, 0);
        handle.explicit_drop(&root);
    }

    /// In debug builds an implicit drop must panic.
    #[test]
    #[cfg(debug_assertions)]
    #[should_panic]
    fn drop_without_lock_panics_with_debug_assertions() {
        let root = Root::new();
        let unreleased = RootedRc::new(&root, 0);
        drop(unreleased);
    }

    /// In release builds an implicit drop leaks the value instead of freeing it.
    #[test]
    #[cfg(not(debug_assertions))]
    fn drop_without_lock_leaks_without_debug_assertions() {
        let root = Root::new();
        let witness = std::rc::Rc::new(());
        let unreleased = RootedRc::new(&root, witness.clone());
        drop(unreleased);
        // The copy stored inside the leaked handle still holds a reference.
        assert_eq!(std::rc::Rc::strong_count(&witness), 2);
    }

    /// A handle and its root can be moved to another thread and used there.
    #[test]
    fn send_to_worker_thread() {
        let root = Root::new();
        let rc = RootedRc::new(&root, 0);
        let worker = thread::spawn(move || {
            let _ = *rc + 2;
            rc.explicit_drop(&root)
        });
        worker.join().unwrap();
    }

    /// A root remains usable after a round trip through another thread.
    #[test]
    fn send_to_worker_thread_and_retrieve() {
        let root = Root::new();
        let worker = thread::spawn(move || {
            let on_worker = RootedRc::new(&root, 0);
            on_worker.explicit_drop(&root);
            root
        });
        let root = worker.join().unwrap();
        let back_home = RootedRc::new(&root, 0);
        back_home.explicit_drop(&root);
    }

    /// A clone can be released on a worker thread while the original is
    /// released afterwards on the spawning thread.
    #[test]
    fn clone_to_worker_thread() {
        let root = Root::new();
        let rc = RootedRc::new(&root, 0);

        let rc_thread = rc.clone(&root);

        let worker = thread::spawn(move || {
            let _ = *rc_thread;
            rc_thread.explicit_drop(&root);
            root
        });
        let root = worker.join().unwrap();

        rc.explicit_drop(&root);
    }

    /// Many threads clone and release handles while serializing on a mutex
    /// that guards the root.
    #[test]
    fn threads_contend_over_lock() {
        let root = Arc::new(std::sync::Mutex::new(Root::new()));
        let rc = RootedRc::new(&root.lock().unwrap(), 0);

        let mut workers = Vec::with_capacity(100);
        for _ in 0..100 {
            let rc = rc.clone(&root.lock().unwrap());
            let root = Arc::clone(&root);

            workers.push(thread::spawn(move || {
                let guard = root.lock().unwrap();
                let second = rc.clone(&guard);
                rc.explicit_drop(&guard);
                second.explicit_drop(&guard);
            }));
        }

        for worker in workers {
            worker.join().unwrap();
        }

        rc.explicit_drop(&root.lock().unwrap());
    }

    /// `into_inner` yields the value only for the last strong handle.
    #[test]
    fn into_inner_recursive() {
        let root = Root::new();
        let inner = RootedRc::new(&root, ());
        let first_outer = RootedRc::new(&root, inner);
        let second_outer = first_outer.clone(&root);

        // Another strong handle is still alive, so nothing is returned.
        let early = RootedRc::into_inner(first_outer, &root);
        assert!(early.is_none());

        // The last strong handle gives the inner handle back.
        let recovered = RootedRc::into_inner(second_outer, &root).unwrap();

        recovered.explicit_drop(&root);
    }

    /// Releasing through the `ExplicitDrop` trait works for a unit payload.
    #[test]
    fn explicit_drop() {
        let root = Root::new();
        let handle = RootedRc::new(&root, ());
        handle.explicit_drop(&root);
    }

    /// `explicit_drop_recursive` also releases handles nested in the payload.
    #[test]
    fn explicit_drop_recursive() {
        // Payload type that itself owns a handle needing explicit release.
        struct MyOuter(RootedRc<()>);
        impl ExplicitDrop for MyOuter {
            type ExplicitDropParam = Root;
            type ExplicitDropResult = ();

            fn explicit_drop(self, root: &Self::ExplicitDropParam) -> Self::ExplicitDropResult {
                self.0.explicit_drop(root);
            }
        }

        let root = Root::new();
        let inner = RootedRc::new(&root, ());
        let first = RootedRc::new(&root, MyOuter(inner));
        let inner_again = first.0.clone(&root);
        let second = RootedRc::new(&root, MyOuter(inner_again));
        first.explicit_drop_recursive(&root, &root);
        second.explicit_drop_recursive(&root, &root);
    }
}
432
433pub struct RootedRcWeak<T> {
434    common: RootedRcCommon<T>,
435}
436
437impl<T> RootedRcWeak<T> {
438    #[inline]
439    pub fn upgrade(&self, root: &Root) -> Option<RootedRc<T>> {
440        let internal = self.common.borrow_internal(root);
441
442        if internal.strong_count.get() == 0 {
443            return None;
444        }
445
446        Some(RootedRc {
447            common: self.common.clone(root, RefType::Strong),
448        })
449    }
450
451    #[inline]
457    pub fn clone(&self, root: &Root) -> Self {
458        Self {
459            common: self.common.clone(root, RefType::Weak),
460        }
461    }
462}
463
464impl<T> ExplicitDrop for RootedRcWeak<T> {
465    type ExplicitDropParam = Root;
466
467    type ExplicitDropResult = ();
468
469    #[inline]
470    fn explicit_drop(self, root: &Self::ExplicitDropParam) -> Self::ExplicitDropResult {
471        let val = self.common.safely_drop(root, RefType::Weak);
472        debug_assert!(val.is_none());
475    }
476}
477
478unsafe impl<T: Sync + Send> Send for RootedRcWeak<T> {}
482unsafe impl<T: Sync + Send> Sync for RootedRcWeak<T> {}
483
#[cfg(test)]
mod test_rooted_rc_weak {
    use super::*;

    /// Upgrading succeeds while a strong handle is still alive.
    #[test]
    fn successful_upgrade() {
        let root = Root::new();
        let strong = RootedRc::new(&root, 42);
        let weak = RootedRc::downgrade(&strong, &root);

        let upgraded = weak.upgrade(&root).unwrap();

        assert_eq!(*upgraded, *strong);

        upgraded.explicit_drop(&root);
        weak.explicit_drop(&root);
        strong.explicit_drop(&root);
    }

    /// Upgrading fails once the last strong handle has been released.
    #[test]
    fn failed_upgrade() {
        let root = Root::new();
        let strong = RootedRc::new(&root, 42);
        let weak = RootedRc::downgrade(&strong, &root);

        strong.explicit_drop(&root);

        let attempt = weak.upgrade(&root);
        assert!(attempt.is_none());

        weak.explicit_drop(&root);
    }

    /// In debug builds an implicit drop of a weak handle must panic.
    #[test]
    #[cfg(debug_assertions)]
    #[should_panic]
    fn drop_without_lock_panics_with_debug_assertions() {
        let root = Root::new();
        let strong = RootedRc::new(&root, 42);
        let unreleased = RootedRc::downgrade(&strong, &root);
        drop(unreleased);
        strong.explicit_drop(&root);
    }

    /// A value may hold a weak handle to itself and release it from its own
    /// `Drop` impl.
    #[test]
    fn circular_reference() {
        std::thread_local! {
            static THREAD_ROOT: Root = Root::new();
        }

        struct MyStruct {
            weak_self: Cell<Option<RootedRcWeak<Self>>>,
        }
        impl MyStruct {
            fn new() -> RootedRc<Self> {
                THREAD_ROOT.with(|root| {
                    let empty = MyStruct {
                        weak_self: Cell::new(None),
                    };
                    let strong = RootedRc::new(root, empty);
                    // Store the self-reference after construction.
                    let weak = RootedRc::downgrade(&strong, root);
                    strong.weak_self.set(Some(weak));
                    strong
                })
            }
        }
        impl Drop for MyStruct {
            fn drop(&mut self) {
                let weak = self.weak_self.take().unwrap();
                THREAD_ROOT.with(|root| weak.explicit_drop(root));
            }
        }

        let val = MyStruct::new();
        THREAD_ROOT.with(|root| val.explicit_drop(root));
    }

    /// In release builds, implicitly dropping a weak handle does not stop
    /// the value from being dropped when the last strong handle goes away.
    #[test]
    #[cfg(not(debug_assertions))]
    fn drop_without_lock_doesnt_leak_value() {
        let root = Root::new();
        let witness = std::rc::Rc::new(());
        let strong = RootedRc::new(&root, witness.clone());
        let unreleased = RootedRc::downgrade(&strong, &root);
        drop(unreleased);
        strong.explicit_drop(&root);

        // Only `witness` itself still holds a reference.
        assert_eq!(std::rc::Rc::strong_count(&witness), 1);
    }
}