1use lazy_static::lazy_static;
26
27use shadow_pod::Pod;
28use vasi::VirtualAddressSpaceIndependent;
29use vasi_sync::scmutex::SelfContainedMutex;
30
31pub fn shmalloc<T>(val: T) -> ShMemBlock<'static, T>
34where
35    T: Sync + VirtualAddressSpaceIndependent,
36{
37    register_teardown();
38    SHMALLOC.lock().alloc(val)
39}
40
41pub fn shfree<T>(block: ShMemBlock<'static, T>)
43where
44    T: Sync + VirtualAddressSpaceIndependent,
45{
46    SHMALLOC.lock().free(block);
47}
48
49pub unsafe fn shdeserialize<T>(serialized: &ShMemBlockSerialized) -> ShMemBlockAlias<'static, T>
57where
58    T: Sync + VirtualAddressSpaceIndependent,
59{
60    unsafe { SHDESERIALIZER.lock().deserialize(serialized) }
61}
62
/// `atexit` hook (test builds only) that tears down the global allocator.
#[cfg(test)]
extern "C" fn shmalloc_teardown() {
    let mut allocator = SHMALLOC.lock();
    allocator.destruct();
}
67
/// Registers [`shmalloc_teardown`] with `atexit` the first time it is called
/// (test builds only); later calls do nothing.
#[cfg(test)]
#[cfg_attr(miri, ignore)]
fn register_teardown() {
    extern crate std;
    use std::sync::Once;

    static TEARDOWN_REGISTERED: Once = Once::new();
    TEARDOWN_REGISTERED.call_once(|| {
        // SAFETY: `shmalloc_teardown` is an `extern "C" fn()` with no
        // arguments, matching the callback type `atexit` expects.
        unsafe {
            libc::atexit(shmalloc_teardown);
        }
    });
}
82
/// Non-test builds register no `atexit` teardown hook.
#[cfg(not(test))]
fn register_teardown() {
    // Intentionally empty.
}
85
86lazy_static! {
89    static ref SHMALLOC: SelfContainedMutex<SharedMemAllocator<'static>> = {
90        let alloc = SharedMemAllocator::new();
91        SelfContainedMutex::new(alloc)
92    };
93    static ref SHDESERIALIZER: SelfContainedMutex<SharedMemDeserializer<'static>> = {
94        let deserial = SharedMemDeserializer::new();
95        SelfContainedMutex::new(deserial)
96    };
97}
98
99pub struct SharedMemAllocatorDropGuard(());
105
106impl SharedMemAllocatorDropGuard {
107    pub unsafe fn new() -> Self {
111        Self(())
112    }
113}
114
115impl Drop for SharedMemAllocatorDropGuard {
116    fn drop(&mut self) {
117        SHMALLOC.lock().destruct();
118    }
119}
120
121#[derive(Debug)]
130pub struct ShMemBlock<'allocator, T>
131where
132    T: Sync + VirtualAddressSpaceIndependent,
133{
134    block: *mut crate::shmalloc_impl::Block,
135    phantom: core::marker::PhantomData<&'allocator T>,
136}
137
138impl<T> ShMemBlock<'_, T>
139where
140    T: Sync + VirtualAddressSpaceIndependent,
141{
142    pub fn serialize(&self) -> ShMemBlockSerialized {
143        let serialized = SHMALLOC.lock().internal.serialize(self.block);
144        ShMemBlockSerialized {
145            internal: serialized,
146        }
147    }
148}
149
150unsafe impl<T> Sync for ShMemBlock<'_, T> where T: Sync + VirtualAddressSpaceIndependent {}
153unsafe impl<T> Send for ShMemBlock<'_, T> where T: Send + Sync + VirtualAddressSpaceIndependent {}
154
155impl<T> core::ops::Deref for ShMemBlock<'_, T>
156where
157    T: Sync + VirtualAddressSpaceIndependent,
158{
159    type Target = T;
160
161    fn deref(&self) -> &Self::Target {
162        let block = unsafe { &*self.block };
163        &block.get_ref::<T>()[0]
164    }
165}
166
167impl<T> core::ops::Drop for ShMemBlock<'_, T>
168where
169    T: Sync + VirtualAddressSpaceIndependent,
170{
171    fn drop(&mut self) {
172        if !self.block.is_null() {
173            SHMALLOC.lock().internal.dealloc(self.block);
175            self.block = core::ptr::null_mut();
176        }
177    }
178}
179
180#[derive(Debug)]
187pub struct ShMemBlockAlias<'deserializer, T>
188where
189    T: Sync + VirtualAddressSpaceIndependent,
190{
191    block: *mut crate::shmalloc_impl::Block,
192    phantom: core::marker::PhantomData<&'deserializer T>,
193}
194
195unsafe impl<T> Sync for ShMemBlockAlias<'_, T> where T: Sync + VirtualAddressSpaceIndependent {}
198unsafe impl<T> Send for ShMemBlockAlias<'_, T> where T: Send + Sync + VirtualAddressSpaceIndependent {}
199
200impl<T> core::ops::Deref for ShMemBlockAlias<'_, T>
201where
202    T: Sync + VirtualAddressSpaceIndependent,
203{
204    type Target = T;
205
206    fn deref(&self) -> &Self::Target {
207        let block = unsafe { &*self.block };
208        &block.get_ref::<T>()[0]
209    }
210}
211
212#[derive(Copy, Clone, Debug, VirtualAddressSpaceIndependent)]
213#[repr(transparent)]
214pub struct ShMemBlockSerialized {
215    internal: crate::shmalloc_impl::BlockSerialized,
216}
217
218unsafe impl Pod for ShMemBlockSerialized {}
219
220impl core::fmt::Display for ShMemBlockSerialized {
221    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
222        let s =
223            core::str::from_utf8(crate::util::trim_null_bytes(&self.internal.chunk_name).unwrap())
224                .unwrap();
225        write!(f, "{};{}", self.internal.offset, s)
226    }
227}
228
229impl core::str::FromStr for ShMemBlockSerialized {
230    type Err = anyhow::Error;
231
232    fn from_str(s: &str) -> anyhow::Result<Self> {
234        use core::fmt::Write;
235        use formatting_nostd::FormatBuffer;
236
237        if let Some((offset_str, path_str)) = s.split_once(';') {
238            let offset = offset_str
240                .parse::<isize>()
241                .map_err(Err::<(), core::num::ParseIntError>)
242                .unwrap();
243
244            let mut chunk_format = FormatBuffer::<{ crate::util::PATH_MAX_NBYTES }>::new();
245
246            write!(&mut chunk_format, "{}", &path_str).unwrap();
247
248            let mut chunk_name = crate::util::NULL_PATH_BUF;
249            chunk_name
250                .iter_mut()
251                .zip(chunk_format.as_str().as_bytes().iter())
252                .for_each(|(x, y)| *x = *y);
253
254            Ok(ShMemBlockSerialized {
255                internal: crate::shmalloc_impl::BlockSerialized { chunk_name, offset },
256            })
257        } else {
258            Err(anyhow::anyhow!("missing ;"))
259        }
260    }
261}
262
263pub struct SharedMemAllocator<'alloc> {
268    internal: crate::shmalloc_impl::FreelistAllocator,
269    nallocs: isize,
270    phantom: core::marker::PhantomData<&'alloc ()>,
271}
272
273impl<'alloc> SharedMemAllocator<'alloc> {
274    fn new() -> Self {
275        let mut internal = crate::shmalloc_impl::FreelistAllocator::new();
276        internal.init().unwrap();
277
278        Self {
279            internal,
280            nallocs: 0,
281            phantom: Default::default(),
282        }
283    }
284
285    fn alloc<T: Sync + VirtualAddressSpaceIndependent>(&mut self, val: T) -> ShMemBlock<'alloc, T> {
287        let t_nbytes: usize = core::mem::size_of::<T>();
288        let t_alignment: usize = core::mem::align_of::<T>();
289
290        let block = self.internal.alloc(t_nbytes, t_alignment);
291        unsafe {
292            (*block).get_mut_ref::<T>()[0] = val;
293        }
294
295        self.nallocs += 1;
296        ShMemBlock::<'alloc, T> {
297            block,
298            phantom: Default::default(),
299        }
300    }
301
302    fn free<T: Sync + VirtualAddressSpaceIndependent>(&mut self, mut block: ShMemBlock<'alloc, T>) {
303        self.nallocs -= 1;
304        block.block = core::ptr::null_mut();
305        self.internal.dealloc(block.block);
306    }
307
308    fn destruct(&mut self) {
309        self.internal.destruct();
318    }
319}
320
321unsafe impl Send for SharedMemAllocator<'_> {}
322unsafe impl Sync for SharedMemAllocator<'_> {}
323
324pub struct SharedMemDeserializer<'alloc> {
349    internal: crate::shmalloc_impl::FreelistDeserializer,
350    phantom: core::marker::PhantomData<&'alloc ()>,
351}
352
353impl<'alloc> SharedMemDeserializer<'alloc> {
354    fn new() -> Self {
355        let internal = crate::shmalloc_impl::FreelistDeserializer::new();
356
357        Self {
358            internal,
359            phantom: Default::default(),
360        }
361    }
362
363    pub unsafe fn deserialize<T>(
369        &mut self,
370        serialized: &ShMemBlockSerialized,
371    ) -> ShMemBlockAlias<'alloc, T>
372    where
373        T: Sync + VirtualAddressSpaceIndependent,
374    {
375        let block = self.internal.deserialize(&serialized.internal);
376
377        ShMemBlockAlias {
378            block,
379            phantom: Default::default(),
380        }
381    }
382}
383
384unsafe impl Send for SharedMemDeserializer<'_> {}
385unsafe impl Sync for SharedMemDeserializer<'_> {}
386
#[cfg(test)]
mod tests {
    use super::*;
    use rand::Rng;
    use std::str::FromStr;
    use std::string::ToString;
    use std::sync::atomic::{AtomicI32, Ordering};

    extern crate std;

    /// Allocates and frees blocks in random-sized batches. After each batch,
    /// checks that every live block still holds the value it was created with.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn allocator_random_allocations() {
        const NROUNDS: usize = 100;
        let mut marked_blocks: std::vec::Vec<(u32, ShMemBlock<u32>)> = Default::default();
        let mut rng = rand::rng();

        let mut execute_round = || {
            // Allocate a batch of blocks, each tagged with its expected value.
            for i in 0..255 {
                let b = shmalloc(i);
                marked_blocks.push((i, b));
            }

            // Free a random number of the newest blocks. A `u8` is at most 255,
            // so this never pops more than the batch just pushed.
            let n1: u8 = rng.random();

            for _ in 0..n1 {
                let last_marked_block = marked_blocks.pop().unwrap();
                assert_eq!(last_marked_block.0, *last_marked_block.1);
                shfree(last_marked_block.1);
            }

            // Surviving blocks must be unaffected by the frees.
            for block in &marked_blocks {
                assert_eq!(block.0, *block.1);
            }
        };

        for _ in 0..NROUNDS {
            execute_round();
        }

        while let Some(b) = marked_blocks.pop() {
            shfree(b.1);
        }
    }

    /// serialize -> Display -> FromStr -> shdeserialize yields the same value.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn round_trip_through_serializer() {
        type T = i32;
        let x: T = 42;

        let original_block: ShMemBlock<T> = shmalloc(x);
        {
            let serialized_block = original_block.serialize();
            let serialized_str = serialized_block.to_string();
            let serialized_block = ShMemBlockSerialized::from_str(&serialized_str).unwrap();
            let block = unsafe { shdeserialize::<i32>(&serialized_block) };
            assert_eq!(*block, 42);
        }

        shfree(original_block);
    }

    /// Writes through either the owner or the alias are visible through both.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn mutations() {
        type T = AtomicI32;
        let original_block = shmalloc(AtomicI32::new(0));

        let serialized_block = original_block.serialize();

        let deserialized_block = unsafe { shdeserialize::<T>(&serialized_block) };

        assert_eq!(original_block.load(Ordering::SeqCst), 0);
        assert_eq!(deserialized_block.load(Ordering::SeqCst), 0);

        original_block.store(10, Ordering::SeqCst);
        assert_eq!(original_block.load(Ordering::SeqCst), 10);
        assert_eq!(deserialized_block.load(Ordering::SeqCst), 10);

        deserialized_block.store(20, Ordering::SeqCst);
        assert_eq!(original_block.load(Ordering::SeqCst), 20);
        assert_eq!(deserialized_block.load(Ordering::SeqCst), 20);

        shfree(original_block);
    }

    /// Moving a `ShMemBlock` handle does not move the data it points to.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn shmemblock_stable_pointer() {
        type T = u32;
        let original_block: ShMemBlock<T> = shmalloc(0);

        let block_addr = &original_block as *const ShMemBlock<T>;
        // Take the address of the stored value (`&*`). The form
        // `*original_block as *const T` would instead cast the value 0 to a
        // pointer, making the final assertion vacuous.
        let data_addr = &*original_block as *const T;

        let block = Some(original_block);

        let new_block_addr = block.as_ref().unwrap() as *const ShMemBlock<T>;
        assert_ne!(block_addr, new_block_addr);

        let new_data_addr = &**(block.as_ref().unwrap()) as *const T;
        assert_eq!(data_addr, new_data_addr);

        #[allow(clippy::unnecessary_literal_unwrap)]
        shfree(block.unwrap());
    }

    /// Moving a `ShMemBlockAlias` does not move the data it points to.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn shmemblockremote_stable_pointer() {
        type T = u32;
        let alloced_block: ShMemBlock<T> = shmalloc(0);

        let block = unsafe { shdeserialize::<T>(&alloced_block.serialize()) };

        let block_addr = &block as *const ShMemBlockAlias<T>;
        // Address of the stored value, not the value itself cast to a pointer.
        let data_addr = &*block as *const T;

        let block = Some(block);

        let new_block_addr = block.as_ref().unwrap() as *const ShMemBlockAlias<T>;
        assert_ne!(block_addr, new_block_addr);

        let new_data_addr = &**(block.as_ref().unwrap()) as *const T;
        assert_eq!(data_addr, new_data_addr);

        shfree(alloced_block);
    }
}