use lazy_static::lazy_static;

use shadow_pod::Pod;
use vasi::VirtualAddressSpaceIndependent;
use vasi_sync::scmutex::SelfContainedMutex;

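/// Allocates `val` in process-global shared memory and returns an owning
/// [`ShMemBlock`]. The block's memory is released by passing it back to
/// [`shfree`] (or by dropping it).
///
/// A minimal usage sketch (marked `ignore` since it requires the platform
/// shared-memory files this allocator relies on):
///
/// ```ignore
/// let block = shmalloc(42u32);
/// assert_eq!(*block, 42);
/// shfree(block);
/// ```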
pub fn shmalloc<T>(val: T) -> ShMemBlock<'static, T>
where
    T: Sync + VirtualAddressSpaceIndependent,
{
    register_teardown();
    SHMALLOC.lock().alloc(val)
}

pub fn shfree<T>(block: ShMemBlock<'static, T>)
where
    T: Sync + VirtualAddressSpaceIndependent,
{
    SHMALLOC.lock().free(block);
}

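/// Reconstructs a typed view of a block from its serialized form, e.g. one
/// received from another process.
///
/// # Safety
///
/// `serialized` must describe a live allocation of type `T` made by a
/// [`SharedMemAllocator`], and the underlying block must outlive the returned
/// [`ShMemBlockAlias`].
///
/// A round-trip sketch mirroring the tests below (marked `ignore` for the
/// same platform reasons as [`shmalloc`]):
///
/// ```ignore
/// let owned = shmalloc(0i32);
/// let alias = unsafe { shdeserialize::<i32>(&owned.serialize()) };
/// assert_eq!(*alias, 0);
/// shfree(owned);
/// ```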
pub unsafe fn shdeserialize<T>(serialized: &ShMemBlockSerialized) -> ShMemBlockAlias<'static, T>
where
    T: Sync + VirtualAddressSpaceIndependent,
{
    unsafe { SHDESERIALIZER.lock().deserialize(serialized) }
}

#[cfg(test)]
extern "C" fn shmalloc_teardown() {
    SHMALLOC.lock().destruct();
}

// `libc::atexit` is not usable under Miri, so only register the teardown
// hook in regular (non-Miri) test builds. Note that `#[cfg_attr(miri, ignore)]`
// has no effect here: the `ignore` attribute only applies to `#[test]` fns.
#[cfg(all(test, not(miri)))]
fn register_teardown() {
    extern crate std;
    use std::sync::Once;

    static START: Once = Once::new();
    START.call_once(|| unsafe {
        libc::atexit(shmalloc_teardown);
    });
}

#[cfg(any(not(test), miri))]
fn register_teardown() {}

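// Process-global singletons: SHMALLOC owns this process's shared memory
// allocations, while SHDESERIALIZER maps in blocks serialized elsewhere. Both
// sit behind a `SelfContainedMutex` so they can be shared across threads.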
lazy_static! {
    static ref SHMALLOC: SelfContainedMutex<SharedMemAllocator<'static>> = {
        let alloc = SharedMemAllocator::new();
        SelfContainedMutex::new(alloc)
    };
    static ref SHDESERIALIZER: SelfContainedMutex<SharedMemDeserializer<'static>> = {
        let deserial = SharedMemDeserializer::new();
        SelfContainedMutex::new(deserial)
    };
}

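/// Guard that destroys the process-global allocator when dropped, releasing
/// the shared memory backing it.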
pub struct SharedMemAllocatorDropGuard(());

impl SharedMemAllocatorDropGuard {
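    /// # Safety
    ///
    /// The returned guard must outlive every `ShMemBlock` allocated from the
    /// global allocator in this process; dropping the guard destroys the
    /// allocator's backing storage.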
    pub unsafe fn new() -> Self {
        Self(())
    }
}

impl Drop for SharedMemAllocatorDropGuard {
    fn drop(&mut self) {
        SHMALLOC.lock().destruct();
    }
}

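/// An owned handle to a `T` stored in shared memory; dropping the handle
/// deallocates the underlying block.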
#[derive(Debug)]
pub struct ShMemBlock<'allocator, T>
where
    T: Sync + VirtualAddressSpaceIndependent,
{
    block: *mut crate::shmalloc_impl::Block,
    phantom: core::marker::PhantomData<&'allocator T>,
}

impl<T> ShMemBlock<'_, T>
where
    T: Sync + VirtualAddressSpaceIndependent,
{
    pub fn serialize(&self) -> ShMemBlockSerialized {
        let serialized = SHMALLOC.lock().internal.serialize(self.block);
        ShMemBlockSerialized {
            internal: serialized,
        }
    }
}

unsafe impl<T> Sync for ShMemBlock<'_, T> where T: Sync + VirtualAddressSpaceIndependent {}
unsafe impl<T> Send for ShMemBlock<'_, T> where T: Send + Sync + VirtualAddressSpaceIndependent {}

impl<T> core::ops::Deref for ShMemBlock<'_, T>
where
    T: Sync + VirtualAddressSpaceIndependent,
{
    type Target = T;

    fn deref(&self) -> &Self::Target {
        let block = unsafe { &*self.block };
        &block.get_ref::<T>()[0]
    }
}

impl<T> core::ops::Drop for ShMemBlock<'_, T>
where
    T: Sync + VirtualAddressSpaceIndependent,
{
    fn drop(&mut self) {
        if !self.block.is_null() {
            SHMALLOC.lock().internal.dealloc(self.block);
            self.block = core::ptr::null_mut();
        }
    }
}

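/// A borrowed view of a block owned elsewhere (typically by another process).
/// Unlike [`ShMemBlock`], dropping an alias does not free the underlying
/// memory.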
#[derive(Debug)]
pub struct ShMemBlockAlias<'deserializer, T>
where
    T: Sync + VirtualAddressSpaceIndependent,
{
    block: *mut crate::shmalloc_impl::Block,
    phantom: core::marker::PhantomData<&'deserializer T>,
}

unsafe impl<T> Sync for ShMemBlockAlias<'_, T> where T: Sync + VirtualAddressSpaceIndependent {}
unsafe impl<T> Send for ShMemBlockAlias<'_, T> where T: Send + Sync + VirtualAddressSpaceIndependent {}

impl<T> core::ops::Deref for ShMemBlockAlias<'_, T>
where
    T: Sync + VirtualAddressSpaceIndependent,
{
    type Target = T;

    fn deref(&self) -> &Self::Target {
        let block = unsafe { &*self.block };
        &block.get_ref::<T>()[0]
    }
}

#[derive(Copy, Clone, Debug, VirtualAddressSpaceIndependent)]
#[repr(transparent)]
pub struct ShMemBlockSerialized {
    internal: crate::shmalloc_impl::BlockSerialized,
}

unsafe impl Pod for ShMemBlockSerialized {}

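// Serialized blocks round-trip through the textual form "<offset>;<chunk_name>"
// (a hypothetical example: "1024;shmem_chunk"). `Display` below produces this
// form; `FromStr` parses it back.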
impl core::fmt::Display for ShMemBlockSerialized {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let s =
            core::str::from_utf8(crate::util::trim_null_bytes(&self.internal.chunk_name).unwrap())
                .unwrap();
        write!(f, "{};{}", self.internal.offset, s)
    }
}

impl core::str::FromStr for ShMemBlockSerialized {
    type Err = anyhow::Error;

    fn from_str(s: &str) -> anyhow::Result<Self> {
        use core::fmt::Write;
        use formatting_nostd::FormatBuffer;

        if let Some((offset_str, path_str)) = s.split_once(';') {
            // Report a malformed offset as an error instead of panicking.
            let offset = offset_str
                .parse::<isize>()
                .map_err(|e| anyhow::anyhow!("invalid offset: {e}"))?;

            let mut chunk_format = FormatBuffer::<{ crate::util::PATH_MAX_NBYTES }>::new();

            write!(&mut chunk_format, "{}", &path_str).unwrap();

            let mut chunk_name = crate::util::NULL_PATH_BUF;
            chunk_name
                .iter_mut()
                .zip(chunk_format.as_str().as_bytes().iter())
                .for_each(|(x, y)| *x = *y);

            Ok(ShMemBlockSerialized {
                internal: crate::shmalloc_impl::BlockSerialized { chunk_name, offset },
            })
        } else {
            Err(anyhow::anyhow!("missing ;"))
        }
    }
}

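/// The process-wide shared memory allocator: wraps the freelist allocator and
/// tracks the number of outstanding allocations.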
pub struct SharedMemAllocator<'alloc> {
    internal: crate::shmalloc_impl::FreelistAllocator,
    nallocs: isize,
    phantom: core::marker::PhantomData<&'alloc ()>,
}

impl<'alloc> SharedMemAllocator<'alloc> {
    fn new() -> Self {
        let mut internal = crate::shmalloc_impl::FreelistAllocator::new();
        internal.init().unwrap();

        Self {
            internal,
            nallocs: 0,
            phantom: Default::default(),
        }
    }

    fn alloc<T: Sync + VirtualAddressSpaceIndependent>(&mut self, val: T) -> ShMemBlock<'alloc, T> {
        let t_nbytes: usize = core::mem::size_of::<T>();
        let t_alignment: usize = core::mem::align_of::<T>();

        let block = self.internal.alloc(t_nbytes, t_alignment);
        // Move `val` into the freshly allocated shared-memory block.
        unsafe {
            (*block).get_mut_ref::<T>()[0] = val;
        }

        self.nallocs += 1;
        ShMemBlock::<'alloc, T> {
            block,
            phantom: Default::default(),
        }
    }

    fn free<T: Sync + VirtualAddressSpaceIndependent>(&mut self, mut block: ShMemBlock<'alloc, T>) {
        self.nallocs -= 1;
        // Deallocate first, then null the pointer so `ShMemBlock`'s `Drop`
        // impl skips the already-freed block.
        self.internal.dealloc(block.block);
        block.block = core::ptr::null_mut();
    }

    fn destruct(&mut self) {
        self.internal.destruct();
    }
}

unsafe impl Send for SharedMemAllocator<'_> {}
unsafe impl Sync for SharedMemAllocator<'_> {}

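/// Maps blocks serialized by a [`SharedMemAllocator`], possibly in another
/// process, back into this process's address space.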
pub struct SharedMemDeserializer<'alloc> {
    internal: crate::shmalloc_impl::FreelistDeserializer,
    phantom: core::marker::PhantomData<&'alloc ()>,
}

impl<'alloc> SharedMemDeserializer<'alloc> {
    fn new() -> Self {
        let internal = crate::shmalloc_impl::FreelistDeserializer::new();

        Self {
            internal,
            phantom: Default::default(),
        }
    }

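    /// # Safety
    ///
    /// `serialized` must refer to a live block of type `T`; see
    /// [`shdeserialize`] for the contract.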
    pub unsafe fn deserialize<T>(
        &mut self,
        serialized: &ShMemBlockSerialized,
    ) -> ShMemBlockAlias<'alloc, T>
    where
        T: Sync + VirtualAddressSpaceIndependent,
    {
        let block = self.internal.deserialize(&serialized.internal);

        ShMemBlockAlias {
            block,
            phantom: Default::default(),
        }
    }
}

unsafe impl Send for SharedMemDeserializer<'_> {}
unsafe impl Sync for SharedMemDeserializer<'_> {}

#[cfg(test)]
mod tests {
    extern crate std;

    use super::*;
    use rand::Rng;
    use std::str::FromStr;
    use std::string::ToString;
    use std::sync::atomic::{AtomicI32, Ordering};

    #[test]
    #[cfg_attr(miri, ignore)]
    fn allocator_random_allocations() {
        const NROUNDS: usize = 100;
        let mut marked_blocks: std::vec::Vec<(u32, ShMemBlock<u32>)> = Default::default();
        let mut rng = rand::rng();

        let mut execute_round = || {
            // Allocate blocks, each tagged with the value it should hold.
            for i in 0..255 {
                let b = shmalloc(i);
                marked_blocks.push((i, b));
            }

            // Free a random number of the most recently allocated blocks.
            let n1: u8 = rng.random();

            for _ in 0..n1 {
                let last_marked_block = marked_blocks.pop().unwrap();
                assert_eq!(last_marked_block.0, *last_marked_block.1);
                shfree(last_marked_block.1);
            }

            // The surviving blocks should still hold their tagged values.
            for block in &marked_blocks {
                assert_eq!(block.0, *block.1);
            }
        };

        for _ in 0..NROUNDS {
            execute_round();
        }

        while let Some(b) = marked_blocks.pop() {
            shfree(b.1);
        }
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn round_trip_through_serializer() {
        type T = i32;
        let x: T = 42;

        let original_block: ShMemBlock<T> = shmalloc(x);
        {
            let serialized_block = original_block.serialize();
            let serialized_str = serialized_block.to_string();
            let serialized_block = ShMemBlockSerialized::from_str(&serialized_str).unwrap();
            let block = unsafe { shdeserialize::<i32>(&serialized_block) };
            assert_eq!(*block, 42);
        }

        shfree(original_block);
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn mutations() {
        type T = AtomicI32;
        let original_block = shmalloc(AtomicI32::new(0));

        let serialized_block = original_block.serialize();

        let deserialized_block = unsafe { shdeserialize::<T>(&serialized_block) };

        assert_eq!(original_block.load(Ordering::SeqCst), 0);
        assert_eq!(deserialized_block.load(Ordering::SeqCst), 0);

        // A write through the original handle is visible through the alias.
        original_block.store(10, Ordering::SeqCst);
        assert_eq!(original_block.load(Ordering::SeqCst), 10);
        assert_eq!(deserialized_block.load(Ordering::SeqCst), 10);

        // And vice versa.
        deserialized_block.store(20, Ordering::SeqCst);
        assert_eq!(original_block.load(Ordering::SeqCst), 20);
        assert_eq!(deserialized_block.load(Ordering::SeqCst), 20);

        shfree(original_block);
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn shmemblock_stable_pointer() {
        type T = u32;
        let original_block: ShMemBlock<T> = shmalloc(0);

        let block_addr = &original_block as *const ShMemBlock<T>;
        let data_addr = &*original_block as *const T;

        let block = Some(original_block);

        // The handle itself has moved...
        let new_block_addr = block.as_ref().unwrap() as *const ShMemBlock<T>;
        assert_ne!(block_addr, new_block_addr);

        // ...but the shared-memory payload it points to has not.
        let new_data_addr = &**(block.as_ref().unwrap()) as *const T;
        assert_eq!(data_addr, new_data_addr);

        #[allow(clippy::unnecessary_literal_unwrap)]
        shfree(block.unwrap());
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn shmemblockremote_stable_pointer() {
        type T = u32;
        let alloced_block: ShMemBlock<T> = shmalloc(0);

        let block = unsafe { shdeserialize::<T>(&alloced_block.serialize()) };

        let block_addr = &block as *const ShMemBlockAlias<T>;
        let data_addr = &*block as *const T;

        let block = Some(block);

        // The alias handle moved, but the aliased payload did not.
        let new_block_addr = block.as_ref().unwrap() as *const ShMemBlockAlias<T>;
        assert_ne!(block_addr, new_block_addr);

        let new_data_addr = &**(block.as_ref().unwrap()) as *const T;
        assert_eq!(data_addr, new_data_addr);

        shfree(alloced_block);
    }
}