shadow_shim/tls.rs

//! no_std thread-local storage
//!
//! This module provides a level of indirection for thread-local storage, with
//! different options trading off performance and stability. See [`Mode`] for more
//! about each of these.
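//!
//! A rough usage sketch (marked `ignore` rather than compiled as a doctest, since
//! correct use depends on the shim's threading environment): declare a process-wide
//! [`ThreadLocalStorage`] and [`ShimTlsVar`]s in statics, and have every thread that
//! touches them call [`ThreadLocalStorage::unregister_current_thread`] before exiting.
//!
//! ```ignore
//! static TLS: ThreadLocalStorage = unsafe { ThreadLocalStorage::new(Mode::NativeTlsId) };
//! static COUNTER: ShimTlsVar<Cell<u32>> = ShimTlsVar::new(&TLS, || Cell::new(0));
//!
//! fn thread_body() {
//!     // First access initializes this thread's instance with the closure above.
//!     COUNTER.get().set(COUNTER.get().get() + 1);
//!     // Required (except in `Mode::Native`) before this thread exits, since
//!     // thread identifiers can be reused.
//!     unsafe { TLS.unregister_current_thread() };
//! }
//! ```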
use core::cell::{Cell, UnsafeCell};
use core::marker::PhantomData;
use core::mem::MaybeUninit;
use core::num::NonZeroUsize;
use core::ops::Deref;
use core::sync::atomic::{self, AtomicUsize};

use num_enum::{IntoPrimitive, TryFromPrimitive};
use rustix::process::Pid;
use vasi_sync::atomic_tls_map::{self, AtomicTlsMap};
use vasi_sync::lazy_lock::{self, LazyLock};

use crate::mmap_box::MmapBox;

/// Modes of operation for this module.
#[derive(Debug, Eq, PartialEq, Copy, Clone, TryFromPrimitive, IntoPrimitive)]
#[repr(i8)]
pub enum Mode {
    /// Delegate back to ELF native thread local storage. This is the fastest
    /// option, and simplest with respect to our own code, but is unsound.
    /// We should probably ultimately disable or remove it.
    ///
    /// In native thread local storage for ELF executables, an access to a
    /// thread-local variable (with C storage specifier `__thread`) from a
    /// dynamically shared object (like the Shadow shim) involves implicitly calling
    /// the libc function `__tls_get_addr`. That function is *not* guaranteed to be
    /// async-signal-safe (See `signal-safety(7)`), and can end up making system
    /// calls and doing memory allocation. This has caused problems with some versions of
    /// glibc (Can't find the issue #...), and recently when running managed
    /// processes compiled with asan <https://github.com/shadow/shadow/issues/2790>.
    ///
    /// SAFETY: `__tls_get_addr` in the linked version of libc must not make
    /// system calls or do anything async-signal unsafe. This basically
    /// can't be ensured, but is often true in practice.
    //
    // TODO: I *think* if we want to avoid the shim linking with libc at all,
    // we'll need to disable this mode at compile-time by removing it or making
    // it a compile-time feature.
    Native,
    /// This mode takes advantage of ELF native thread local storage, but only
    /// leverages it as a cheap-to-retrieve thread identifier. It does not call
    /// into libc or store anything directly in the native thread local storage.
    ///
    /// In particular, based on 3.4.6 and 3.4.2 of [ELF-TLS], we know that we
    /// can retrieve the "thread pointer" by loading offset zero of the `fs`
    /// register; i.e. `%fs:0`.
    ///
    /// The contents of the data pointed to by the thread pointer are an
    /// implementation detail of the compiler, linker, and libc, so we don't
    /// depend on it. However it seems reasonable to assume that this address is
    /// unique to each live thread, and doesn't change during the lifetime of a
    /// thread. Therefore we use the address as a thread-identifier, which we in
    /// turn use as key to our own allocated thread-local-storage.
    ///
    /// This mode is nearly as fast as native, but:
    /// * Assumes that if the "thread pointer" in `%fs:0` is non-NULL for a
    ///   given thread, that it is stable and unique for the lifetime of that
    ///   thread.  This seems like a fairly reasonable assumption, and seems to
    ///   hold so far, but isn't guaranteed.
    /// * Requires that each thread using thread local storage from this module
    ///   calls [`ThreadLocalStorage::unregister_current_thread`] before
    ///   exiting, since the thread pointer may subsequently be used for another
    ///   thread.
    ///
    /// [ELF-TLS]: "ELF Handling For Thread-Local Storage", by Ulrich Drepper.
    /// <https://www.akkadia.org/drepper/tls.pdf>
    ///
    /// SAFETY: Requires that each thread using this thread local storage
    /// calls [`ThreadLocalStorage::unregister_current_thread`] before exiting.
    NativeTlsId,
    /// This mode is similar to `NativeTlsId`, but instead of using the ELF thread
    /// pointer to identify each thread, it uses the system thread ID as retrieved by
    /// the `gettid` syscall.
    ///
    /// Unlike `NativeTlsId`, this approach doesn't rely on any assumptions about
    /// the implementation details of thread local storage in the managed process.
    /// It also *usually* still works without calling [`ThreadLocalStorage::unregister_current_thread`],
    /// but technically still requires it to guarantee soundness, since thread
    /// IDs can be reused.
    ///
    /// Unfortunately this mode is *much* slower than the others.
    ///
    /// SAFETY: Each thread using this thread local storage must call
    /// [`ThreadLocalStorage::unregister_current_thread`] before exiting.
    #[allow(unused)]
    Gettid,
}

/// This needs to be big enough to store all thread-local variables for a single
/// thread. We fail at runtime if this limit is exceeded.
pub const BYTES_PER_THREAD: usize = 1024;

// Max threads for our slow TLS fallback mechanism.  We support recycling
// storage of exited threads, so this is the max *concurrent* threads per
// process.
const TLS_FALLBACK_MAX_THREADS: usize = 100;

/// An ELF thread pointer, as specified in
/// <https://www.akkadia.org/drepper/tls.pdf>
///
/// Guaranteed not to be zero/NULL.
///
/// Only useful for comparisons, since the contents of the pointer are an
/// implementation detail of libc and the linker.
#[derive(Copy, Clone, Eq, PartialEq, Debug, Hash)]
struct ElfThreadPointer(NonZeroUsize);

impl ElfThreadPointer {
    /// Thread pointer for the current thread, if one is set.
    pub fn current() -> Option<Self> {
        // Based on 3.4.6 and 3.4.2 of [ELF-TLS], we retrieve the "thread
        // pointer" by loading offset zero of the `fs` register; i.e. `%fs:0`.
        let fs: usize;
        unsafe { core::arch::asm!("mov {fs}, fs:0x0", fs = out(reg) fs) };
        NonZeroUsize::new(fs).map(Self)
    }
}

// The alignment we choose here becomes the max alignment we can support for
// [`ShimTlsVar`]. 16 is enough for most types, and is e.g. the alignment of
// pointers returned by glibc's `malloc`, but we can increase as needed.
#[repr(C, align(16))]
struct TlsOneThreadStorage {
    // Used as a backing store for instances of `ShimTlsVarStorage`. Must be
    // initialized to zero so that the first time that a given range of bytes is
    // interpreted as a `ShimTlsVarStorage`, `ShimTlsVarStorage::initd` is
    // correctly `false`.
    //
    // `MaybeUninit` because after a given range starts to be used as
    // `ShimTlsVarStorage<T>`, T may *uninitialize* some bytes e.g. due to
    // padding or its own use of `MaybeUninit`.
    bytes: UnsafeCell<[MaybeUninit<u8>; BYTES_PER_THREAD]>,
}

impl TlsOneThreadStorage {
    /// # Safety
    ///
    /// * `alloc` must be dereferenceable and live for the lifetime of this
    ///   process. (A zeroed buffer is a valid dereferenceable instance of
    ///   `Self`)
    /// * `alloc` must *only* be accessed through this function.
    pub unsafe fn from_static_lifetime_zeroed_allocation(
        alloc: *mut TlsOneThreadStorageAllocation,
    ) -> &'static Self {
        type Output = TlsOneThreadStorage;
        static_assertions::assert_eq_align!(TlsOneThreadStorageAllocation, Output);
        static_assertions::assert_eq_size!(TlsOneThreadStorageAllocation, Output);
        unsafe { &*alloc.cast_const().cast::<Output>() }
    }

    #[allow(clippy::new_without_default)]
    pub fn new() -> Self {
        Self {
            bytes: UnsafeCell::new([MaybeUninit::new(0); BYTES_PER_THREAD]),
        }
    }
}

/// This is a "proxy" type to `TlsOneThreadStorage` with the same size and alignment.
///
/// Unlike `TlsOneThreadStorage`, it is exposed to C, so that C code can provide
/// a "thread-local allocator" that we delegate to in [`Mode::Native`].
#[repr(C, align(16))]
#[derive(Copy, Clone)]
pub struct TlsOneThreadStorageAllocation {
    _bytes: [u8; BYTES_PER_THREAD],
}
static_assertions::assert_eq_align!(TlsOneThreadStorageAllocation, TlsOneThreadStorage);
static_assertions::assert_eq_size!(TlsOneThreadStorageAllocation, TlsOneThreadStorage);

/// An opaque, per-thread identifier. These are only guaranteed to be unique for
/// *live* threads. See [`ThreadLocalStorage::unregister_current_thread`].
#[derive(Debug, Eq, PartialEq, Copy, Clone)]
pub struct ThreadLocalStorageKey(FastThreadId);

/// An opaque, per-thread identifier. These are only guaranteed to be unique for
/// *live* threads; in particular [`FastThreadId::ElfThreadPointer`] of a live
/// thread can have the same value as a previously seen dead thread. See
/// [`ThreadLocalStorage::unregister_current_thread`].
///
/// Internal implementation of [`ThreadLocalStorageKey`].
#[derive(Debug, Eq, PartialEq, Copy, Clone, Hash)]
enum FastThreadId {
    ElfThreadPointer(ElfThreadPointer),
    NativeTid(Pid),
}

impl FastThreadId {
    fn to_nonzero_usize(self) -> NonZeroUsize {
        // Kernel-space addresses have the most-significant bit set.
        // https://en.wikipedia.org/wiki/X86-64#Virtual_address_space_details
        //
        // Conversely, user-space addresses do not.
        //
        // The thread pointer value, when set, should contain a user-space
        // address, i.e. this bit should be unset.
        //
        // Since Pids are 32 bits, we can therefore use this bit to distinguish the
        // two "types" of thread IDs.
        const KERNEL_SPACE_BIT: usize = 1 << 63;
        match self {
            FastThreadId::ElfThreadPointer(ElfThreadPointer(fs)) => {
                assert_eq!(fs.get() & KERNEL_SPACE_BIT, 0);
                fs.get().try_into().unwrap()
            }
            FastThreadId::NativeTid(t) => {
                let pid = usize::try_from(t.as_raw_nonzero().get()).unwrap();
                (pid | KERNEL_SPACE_BIT).try_into().unwrap()
            }
        }
    }

    /// Id for the current thread.
    pub fn current() -> Self {
        #[cfg(not(miri))]
        {
            ElfThreadPointer::current()
                .map(Self::ElfThreadPointer)
                .unwrap_or_else(|| Self::NativeTid(rustix::thread::gettid()))
        }
        #[cfg(miri)]
        {
            // In miri we can't use inline assembly to get the fs register or
            // get a numeric thread ID. We have to generate synthetic IDs from
            // `std::thread::ThreadId` instead.

            use std::collections::HashMap;
            use std::sync::Mutex;
            use std::thread::ThreadId;

            static SYNTHETIC_IDS: Mutex<Option<HashMap<ThreadId, FastThreadId>>> = Mutex::new(None);
            static NEXT_ID: AtomicUsize = AtomicUsize::new(1);

            let mut synthetic_ids = SYNTHETIC_IDS.lock().unwrap();
            let synthetic_ids = synthetic_ids.get_or_insert_with(HashMap::new);
            let id = std::thread::current().id();
            *synthetic_ids
                .entry(id)
                .or_insert_with(|| {
                    Self::ElfThreadPointer(ElfThreadPointer(
                        NonZeroUsize::new(
                            NEXT_ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed),
                        )
                        .unwrap(),
                    ))
                })
        }
    }
}

struct TlsOneThreadStorageProducer {}
impl
    lazy_lock::Producer<
        MmapBox<AtomicTlsMap<TLS_FALLBACK_MAX_THREADS, MmapBox<TlsOneThreadStorage>>>,
    > for TlsOneThreadStorageProducer
{
    fn initialize(
        self,
    ) -> MmapBox<AtomicTlsMap<TLS_FALLBACK_MAX_THREADS, MmapBox<TlsOneThreadStorage>>> {
        MmapBox::new(AtomicTlsMap::new())
    }
}

/// Provider for thread local storage. For non-test usage, there should generally
/// be a single process-wide instance.
pub struct ThreadLocalStorage {
    // Allocate lazily via `mmap`, to avoid unnecessarily consuming
    // the memory in processes where we always use native thread local storage.
    storages: LazyLock<
        MmapBox<AtomicTlsMap<TLS_FALLBACK_MAX_THREADS, MmapBox<TlsOneThreadStorage>>>,
        TlsOneThreadStorageProducer,
    >,
    // Next available offset to allocate within `storages`.
    next_offset: AtomicUsize,
    preferred_mode: Mode,
}

impl ThreadLocalStorage {
    /// # Safety
    ///
    /// See [`Mode`] for detailed safety requirements. No matter the preferred
    /// mode, we fall back to [`Mode::Gettid`] if native thread local storage
    /// isn't set up for a given thread, so that mode's requirements must always
    /// be met: each thread using this thread local storage must call
    /// [`Self::unregister_current_thread`] before exiting.
    pub const unsafe fn new(preferred_mode: Mode) -> Self {
        Self {
            storages: LazyLock::const_new(TlsOneThreadStorageProducer {}),
            next_offset: AtomicUsize::new(0),
            preferred_mode,
        }
    }

    fn alloc_offset(&self, align: usize, size: usize) -> usize {
        // The alignment we ensure here is an offset from the base of a [`TlsOneThreadStorage`].
        // It won't be meaningful if [`TlsOneThreadStorage`] has a smaller alignment requirement
        // than this variable.
        assert!(align <= core::mem::align_of::<TlsOneThreadStorage>());
        let mut next_offset_val = self.next_offset.load(atomic::Ordering::Relaxed);
        loop {
            // Create a synthetic pointer just so we can call `align_offset`
            // instead of doing the fiddly math ourselves.  This is sound, but
            // causes miri to generate a warning.  We should use
            // `core::ptr::invalid` here once stabilized to make our intent
            // explicit that yes, really, we want to make an invalid pointer
            // that we have no intention of dereferencing.
            let fake: *const u8 = next_offset_val as *const u8;
            let this_var_offset = next_offset_val + fake.align_offset(align);

            let next_next_offset_val = this_var_offset + size;
            if next_next_offset_val > BYTES_PER_THREAD {
                panic!(
                    "Exceeded hard-coded limit of {BYTES_PER_THREAD} per thread of thread local storage"
                );
            }

            match self.next_offset.compare_exchange(
                next_offset_val,
                next_next_offset_val,
                atomic::Ordering::Relaxed,
                atomic::Ordering::Relaxed,
            ) {
                Ok(_) => return this_var_offset,
                Err(v) => {
                    // We raced with another thread. This *shouldn't* happen in
                    // the shadow shim, since only one thread is allowed to run
                    // at a time, but handle it gracefully. Update the current
                    // value of the atomic and try again.
                    next_offset_val = v;
                }
            }
        }
    }

    /// Returns thread local storage for the current thread. The raw byte contents
    /// are initialized to zero.
    fn current_thread_storage(&self) -> TlsOneThreadBackingStoreRef {
        if let Some(ThreadLocalStorageKey(id)) = self.current_key() {
            // SAFETY: `id` is unique to this live thread, and caller guarantees
            // any previous thread with this `id` has been removed.
            let res = unsafe {
                self.storages
                    .deref()
                    .get_or_insert_with(id.to_nonzero_usize(), || {
                        MmapBox::new(TlsOneThreadStorage::new())
                    })
            };
            TlsOneThreadBackingStoreRef::Mapped(res)
        } else {
            // Use native (libc) TLS.
            let alloc: *mut TlsOneThreadStorageAllocation =
                unsafe { crate::bindings::shim_native_tls() };
            TlsOneThreadBackingStoreRef::Native(unsafe {
                TlsOneThreadStorage::from_static_lifetime_zeroed_allocation(alloc)
            })
        }
    }

    /// Release this thread's thread local storage and exit the thread.
    ///
    /// Should be called by every thread that accesses thread local storage.
    /// This is a no-op when using native thread-local storage, but is required for
    /// correctness otherwise, since thread IDs can be reused.
    ///
    /// Panics if there are still any live references to this thread's [`ShimTlsVar`]s.
    ///
    /// # Safety
    ///
    /// The calling thread must not access this [`ThreadLocalStorage`] again
    /// before exiting.
    pub unsafe fn unregister_current_thread(&self) {
        if !self.storages.initd() {
            // Nothing to do. Even if another thread happens to be initializing
            // concurrently, we know the *current* thread isn't registered.
            return;
        }

        let storages = self.storages.force();

        let id = FastThreadId::current();
        // SAFETY: `id` is unique to this live thread, and caller guarantees
        // any previous thread with this `id` has been removed.
        unsafe { storages.remove(id.to_nonzero_usize()) };
    }

    /// An opaque key referencing this thread's thread-local-storage.
    ///
    /// `None` if the current thread uses native TLS.
    pub fn current_key(&self) -> Option<ThreadLocalStorageKey> {
        match self.preferred_mode {
            Mode::Native if ElfThreadPointer::current().is_some() => {
                // Native (libc) TLS seems to be set up properly. We'll use that,
                // so there is no storage key.
                None
            }
            // Use our fallback mechanism.
            _ => Some(ThreadLocalStorageKey(FastThreadId::current())),
        }
    }

    /// Reassigns storage from `prev_key` to the current thread, and drops
    /// storage for all other threads.
    ///
    /// Meant to be called after forking a new process from a thread with key
    /// `prev_key`.
    ///
    /// # Safety
    ///
    /// `self` must not be shared with any other threads. Typically this is ensured
    /// by calling this function after `fork` (but *not* `vfork`), and before any
    /// additional threads are created from the new process.
    ///
    /// The current thread must have the same native thread local storage as the
    /// parent; it is sufficient for the parent to *not* have used CLONE_SETTLS when
    /// creating the current thread.
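    ///
    /// A sketch of the intended call pattern (marked `ignore`; error handling is
    /// elided and `libc::fork` stands in for however the process is actually forked):
    ///
    /// ```ignore
    /// // In the parent, before forking, remember the forking thread's key.
    /// let prev_key = tls.current_key();
    /// if unsafe { libc::fork() } == 0 {
    ///     // In the child: hand the parent thread's storage to the current
    ///     // (only) thread, before creating any other threads.
    ///     unsafe { tls.fork_from(prev_key) };
    /// }
    /// ```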
    pub unsafe fn fork_from(&self, prev_key: Option<ThreadLocalStorageKey>) {
        let prev_storage = prev_key.map(|id| {
            // SAFETY: Previous thread doesn't exist in this process.
            unsafe { self.storages.remove(id.0.to_nonzero_usize()).unwrap() }
        });

        // SAFETY: Caller guarantees nothing else is accessing thread local
        // storage.
        unsafe {
            self.storages.forget_all();
        }

        let curr_key = self.current_key();
        match (prev_storage, curr_key) {
            (None, None) => {
                // Both parent and current use native storage. Caller guarantees
                // that it's the same storage (e.g. no CLONE_SETTLS flag).
            }
            (Some(prev_storage), Some(curr_key)) => {
                // Move storage to new key.
                unsafe {
                    self.storages
                        .get_or_insert_with(curr_key.0.to_nonzero_usize(), move || prev_storage)
                };
            }
            _ => {
                // Need to migrate thread local storage between native and table.
                //
                // table -> native might not be too bad. We should be able to write
                // the storage we retrieved from the table into native TLS.
                //
                // Not sure how to implement native -> table. I think we'd need
                // to make the backing storage clonable, and clone it from the
                // parent process so that we can access it from the child.
                unimplemented!()
            }
        }
    }
}

enum TlsOneThreadBackingStoreRef<'tls> {
    Native(&'static TlsOneThreadStorage),
    Mapped(atomic_tls_map::Ref<'tls, MmapBox<TlsOneThreadStorage>>),
}

impl Deref for TlsOneThreadBackingStoreRef<'_> {
    type Target = TlsOneThreadStorage;

    fn deref(&self) -> &Self::Target {
        match self {
            TlsOneThreadBackingStoreRef::Native(n) => n,
            TlsOneThreadBackingStoreRef::Mapped(m) => m,
        }
    }
}

/// One of these is placed in each thread's thread-local-storage, for each
/// `ShimTlsVar`.
///
/// It is designed to be safely *zeroable*, and for all zeroes to be the correct
/// initial state, indicating that the actual value of the variable hasn't yet
/// been initialized. This is because the first access for each of these
/// variables is from the middle of an array of zeroed bytes in
/// `TlsOneThreadStorage`.
///
/// TODO: Consider adding a process-global identifier for each [`ShimTlsVar`],
/// and a map mapping each of those to the init state. We then wouldn't
/// need the internal `initd` flag, and maybe wouldn't need this type at all.
/// It would mean another map-lookup on every thread-local access, but it'd
/// probably be ok.
struct ShimTlsVarStorage<T> {
    // Whether the var has been initialized for this thread.
    //
    // For `bool`, the bit pattern 0 is guaranteed to represent
    // `false`, and `Cell` has the same layout as its inner type. Hence,
    // interpreting 0-bytes as `Self` is sound and correctly indicates that it hasn't
    // been initialized.
    // <https://doc.rust-lang.org/std/cell/struct.Cell.html#memory-layout>
    // <https://doc.rust-lang.org/reference/types/boolean.html>
    initd: Cell<bool>,
    value: UnsafeCell<MaybeUninit<T>>,
}

impl<T> ShimTlsVarStorage<T> {
    fn get(&self) -> &T {
        assert!(self.initd.get());

        // SAFETY: We've ensured this value is initialized, and that
        // there are no exclusive references created after initialization.
        unsafe { (*self.value.get()).assume_init_ref() }
    }

    fn ensure_init(&self, initializer: impl FnOnce() -> T) {
        if !self.initd.get() {
            // Initialize the value.

            // SAFETY: This thread has exclusive access to the underlying storage.
            // This is the only place we ever construct a mutable reference to this
            // value, and we know we've never constructed a reference before, since
            // the data isn't initialized.
            let value: &mut MaybeUninit<T> = unsafe { &mut *self.value.get() };
            value.write(initializer());
            self.initd.set(true);
        }
    }
}

/// An initializer for internal use with `LazyLock`. We need an explicit type
/// instead of just a closure so that we can name the type in `ShimTlsVar`'s
/// definition.
struct OffsetInitializer<'tls> {
    tls: &'tls ThreadLocalStorage,
    align: usize,
    size: usize,
}

impl<'tls> OffsetInitializer<'tls> {
    pub const fn new<T>(tls: &'tls ThreadLocalStorage) -> Self {
        Self {
            tls,
            align: core::mem::align_of::<ShimTlsVarStorage<T>>(),
            size: core::mem::size_of::<ShimTlsVarStorage<T>>(),
        }
    }
}

impl lazy_lock::Producer<usize> for OffsetInitializer<'_> {
    // Finds and assigns the next free and suitably aligned offset within
    // thread-local-storage for a value of type `T`, initialized with function
    // `F`.
    fn initialize(self) -> usize {
        self.tls.alloc_offset(self.align, self.size)
    }
}

/// Thread local storage for a variable of type `T`, initialized on first access
/// by each thread using a function of type `F`.
///
/// The `Drop` implementation of `T` is *not* called, e.g. when threads exit or
/// this value itself is dropped.
//
// TODO: Consider changing API to only provide a `with` method instead of
// allowing access to `'static` references. This would let us validate in
// [`ThreadLocalStorage::unregister_current_thread`] that no variables are
// currently being accessed and enforce that none are accessed afterwards, and
// potentially let us run `Drop` impls (though I think we'd also need an
// allocator for the latter).
pub struct ShimTlsVar<'tls, T, F = fn() -> T>
where
    F: Fn() -> T,
{
    tls: &'tls ThreadLocalStorage,
    // We wrap in a lazy lock to support const initialization of `Self`.
    offset: LazyLock<usize, OffsetInitializer<'tls>>,
    f: F,
    _phantom: PhantomData<T>,
}
// SAFETY: Still `Sync` even if T is `!Sync`, since each thread gets its own
// instance of the value. `F` must still be `Sync`, though, since that *is*
// shared across threads.
unsafe impl<T, F> Sync for ShimTlsVar<'_, T, F> where F: Sync + Fn() -> T {}

impl<'tls, T, F> ShimTlsVar<'tls, T, F>
where
    F: Fn() -> T,
{
    /// Create a variable that will be uniquely instantiated for each thread,
    /// initialized with `f` on first access by each thread.
    ///
    /// Typically this should go in a `static`.
    pub const fn new(tls: &'tls ThreadLocalStorage, f: F) -> Self {
        Self {
            tls,
            offset: LazyLock::const_new(OffsetInitializer::new::<T>(tls)),
            f,
            _phantom: PhantomData,
        }
    }

    /// Access the inner value.
    ///
    /// The returned wrapper can't be sent to or shared with other threads,
    /// since the underlying storage is invalidated when the originating thread
    /// calls [`ThreadLocalStorage::unregister_current_thread`].
    pub fn get<'var>(&'var self) -> TlsVarRef<'tls, 'var, T, F> {
        // SAFETY: This offset into TLS storage is a valid instance of
        // `ShimTlsVarStorage<T>`. We've ensured the correct size and alignment,
        // and the backing bytes have been initialized to 0.
        unsafe { TlsVarRef::new(self) }
    }
}

/// A reference to a single thread's instance of a TLS variable [`ShimTlsVar`].
pub struct TlsVarRef<'tls, 'var, T, F: Fn() -> T> {
    storage: TlsOneThreadBackingStoreRef<'tls>,
    offset: usize,

    // Force to be !Sync and !Send.
    _phantom: core::marker::PhantomData<*mut T>,
    // Defensively bind to lifetime of `ShimTlsVar`.  Currently not technically
    // required, since we don't "deallocate" the backing storage of a `ShimTlsVar`
    // that's uninitialized, and a no-op in "standard" usage since `ShimTlsVar`s
    // generally have a `'static` lifetime, but let's avoid a potential
    // surprising lifetime extension that we shouldn't need.
    _phantom_lifetime: core::marker::PhantomData<&'var ShimTlsVar<'tls, T, F>>,
}
// Double check `!Send` and `!Sync`.
static_assertions::assert_not_impl_any!(TlsVarRef<'static, 'static, (), fn() -> ()>: Send, Sync);

impl<'tls, 'var, T, F: Fn() -> T> TlsVarRef<'tls, 'var, T, F> {
    /// # Safety
    ///
    /// There must be an initialized instance of `ShimTlsVarStorage<T>` at the
    /// address of `&storage.bytes[offset]`.
    unsafe fn new(var: &'var ShimTlsVar<'tls, T, F>) -> Self {
        let storage = var.tls.current_thread_storage();
        let offset = *var.offset.force();
        let this = Self {
            storage,
            offset,
            _phantom: PhantomData,
            _phantom_lifetime: PhantomData,
        };
        this.var_storage().ensure_init(&var.f);
        this
    }

    fn var_storage(&self) -> &ShimTlsVarStorage<T> {
        // SAFETY: We ensured `offset` is in bounds at construction time.
        let this_var_bytes: *mut u8 = {
            let storage: *mut [MaybeUninit<u8>; BYTES_PER_THREAD] = self.storage.bytes.get();
            let storage: *mut u8 = storage.cast();
            unsafe { storage.add(self.offset) }
        };

        let this_var: *const ShimTlsVarStorage<T> = this_var_bytes as *const ShimTlsVarStorage<T>;
        assert_eq!(
            this_var.align_offset(core::mem::align_of::<ShimTlsVarStorage<T>>()),
            0
        );
        // SAFETY: The TLS bytes for each thread are initialized to 0, and
        // all-zeroes is a valid value of `ShimTlsVarStorage<T>`.
        //
        // We've ensured proper alignment when calculating the offset,
        // and verified it in the assertion just above.
        let this_var: &ShimTlsVarStorage<T> = unsafe { &*this_var };
        this_var
    }
}

// there are multiple named lifetimes, so let's just be explicit about them rather than hide them
#[allow(clippy::needless_lifetimes)]
impl<'tls, 'var, T, F: Fn() -> T> Deref for TlsVarRef<'tls, 'var, T, F> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        self.var_storage().get()
    }
}

#[cfg(test)]
mod test {
    use core::cell::RefCell;
    use core::sync::atomic::{self, AtomicI8, AtomicI16, AtomicI32};

    use super::*;

    #[cfg(miri)]
    const MODES: &[Mode] = &[Mode::NativeTlsId, Mode::Gettid];
    #[cfg(not(miri))]
    const MODES: &[Mode] = &[Mode::Native, Mode::NativeTlsId, Mode::Gettid];

    #[cfg(not(miri))]
    #[test_log::test]
    fn test_compile_static_native() {
        static TLS: ThreadLocalStorage = unsafe { ThreadLocalStorage::new(Mode::Native) };
        static MY_VAR: ShimTlsVar<u32> = ShimTlsVar::new(&TLS, || 42);
        assert_eq!(*MY_VAR.get(), 42);
        unsafe { TLS.unregister_current_thread() };
    }
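
    // A test-only sketch (not part of the original suite) of the documented
    // `current_key` behavior: with `Mode::Native` and ELF TLS available there is
    // no fallback key, while the fallback modes always produce one.
    #[cfg(not(miri))]
    #[test_log::test]
    fn test_current_key_per_mode() {
        let native = unsafe { ThreadLocalStorage::new(Mode::Native) };
        // Threads in an ordinary test binary have ELF native TLS set up.
        assert!(native.current_key().is_none());
        unsafe { native.unregister_current_thread() };

        let fallback = unsafe { ThreadLocalStorage::new(Mode::Gettid) };
        assert!(fallback.current_key().is_some());
        unsafe { fallback.unregister_current_thread() };
    }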

    #[test_log::test]
    fn test_compile_static_native_tls_id() {
        static TLS: ThreadLocalStorage = unsafe { ThreadLocalStorage::new(Mode::NativeTlsId) };
        static MY_VAR: ShimTlsVar<u32> = ShimTlsVar::new(&TLS, || 42);
        assert_eq!(*MY_VAR.get(), 42);
        unsafe { TLS.unregister_current_thread() };
    }

    #[test_log::test]
    fn test_compile_static_gettid() {
        static TLS: ThreadLocalStorage = unsafe { ThreadLocalStorage::new(Mode::Gettid) };
        static MY_VAR: ShimTlsVar<u32> = ShimTlsVar::new(&TLS, || 42);
        assert_eq!(*MY_VAR.get(), 42);
        unsafe { TLS.unregister_current_thread() };
    }
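
    // A test-only sketch (not part of the original suite) of the key-space split
    // in `FastThreadId::to_nonzero_usize`: ELF thread pointers are user-space
    // addresses (kernel-space bit clear), while native TIDs get tagged with that
    // bit, so the two kinds of keys can't collide. Mirrors the constant defined
    // in `to_nonzero_usize` and assumes a 64-bit target, like the rest of this
    // module.
    #[cfg(not(miri))]
    #[test_log::test]
    fn test_fast_thread_id_key_spaces() {
        const KERNEL_SPACE_BIT: usize = 1 << 63;

        let tp_key = FastThreadId::ElfThreadPointer(ElfThreadPointer(
            core::num::NonZeroUsize::new(0x7fff_dead_b000).unwrap(),
        ))
        .to_nonzero_usize();
        assert_eq!(tp_key.get() & KERNEL_SPACE_BIT, 0);

        let tid_key = FastThreadId::NativeTid(rustix::thread::gettid()).to_nonzero_usize();
        assert_ne!(tid_key.get() & KERNEL_SPACE_BIT, 0);
    }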

    #[test_log::test]
    fn test_minimal() {
        for mode in MODES {
            let tls = unsafe { ThreadLocalStorage::new(*mode) };
            let var: ShimTlsVar<u32> = ShimTlsVar::new(&tls, || 0);
            assert_eq!(*var.get(), 0);
            unsafe { tls.unregister_current_thread() };
        }
    }
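
    // A test-only sketch (not part of the original suite) exercising `alloc_offset`
    // directly: returned offsets must honor the requested alignment and must not
    // overlap previously allocated ranges.
    #[test_log::test]
    fn test_alloc_offset_alignment_and_no_overlap() {
        let tls = unsafe { ThreadLocalStorage::new(Mode::Gettid) };
        let a = tls.alloc_offset(1, 3);
        let b = tls.alloc_offset(8, 8);
        let c = tls.alloc_offset(2, 2);
        // Offsets honor the requested alignment.
        assert_eq!(b % 8, 0);
        assert_eq!(c % 2, 0);
        // Ranges are allocated in order and don't overlap.
        assert!(a + 3 <= b);
        assert!(b + 8 <= c);
    }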

    #[test]
    #[should_panic(expected = "Removed key while references still held")]
    fn test_panic() {
        let tls = unsafe { ThreadLocalStorage::new(Mode::Gettid) };
        let var: ShimTlsVar<u32> = ShimTlsVar::new(&tls, || 0);
        let _var_ref = var.get();
        // This should panic since we still have a reference
        unsafe { tls.unregister_current_thread() };
    }
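
    // A test-only sketch (not part of the original suite) of the documented
    // behavior that `T`'s `Drop` impl is *not* run when a thread unregisters
    // its storage.
    #[test_log::test]
    fn test_drop_not_called_on_unregister() {
        use core::sync::atomic::AtomicBool;

        static DROPPED: AtomicBool = AtomicBool::new(false);
        struct NotifyDrop;
        impl Drop for NotifyDrop {
            fn drop(&mut self) {
                DROPPED.store(true, atomic::Ordering::Relaxed);
            }
        }

        let tls = unsafe { ThreadLocalStorage::new(Mode::Gettid) };
        let var: ShimTlsVar<NotifyDrop> = ShimTlsVar::new(&tls, || NotifyDrop);
        // Initialize this thread's instance, then drop the reference to it.
        drop(var.get());
        unsafe { tls.unregister_current_thread() };
        assert!(!DROPPED.load(atomic::Ordering::Relaxed));
    }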

    #[test_log::test]
    fn test_single_thread_mutate() {
        for mode in MODES {
            let tls = unsafe { ThreadLocalStorage::new(*mode) };
            let my_var: ShimTlsVar<RefCell<u32>> = ShimTlsVar::new(&tls, || RefCell::new(0));
            assert_eq!(*my_var.get().borrow(), 0);
            *my_var.get().borrow_mut() = 42;
            assert_eq!(*my_var.get().borrow(), 42);
            unsafe { tls.unregister_current_thread() };
        }
    }

    #[test_log::test]
    fn test_multithread_mutate() {
        for mode in MODES {
            let tls = unsafe { ThreadLocalStorage::new(*mode) };
            let my_var: ShimTlsVar<RefCell<i32>> = ShimTlsVar::new(&tls, || RefCell::new(0));
            std::thread::scope(|scope| {
                let tls = &tls;
                let my_var = &my_var;
                let threads = (0..10).map(|i| {
                    scope.spawn(move || {
                        assert_eq!(*my_var.get().borrow(), 0);
                        *my_var.get().borrow_mut() = i;
                        assert_eq!(*my_var.get().borrow(), i);
                        unsafe { tls.unregister_current_thread() };
                    })
                });
                for t in threads {
                    t.join().unwrap();
                }
                unsafe { tls.unregister_current_thread() };
            });
        }
    }

    #[test_log::test]
    fn test_multithread_mutate_small_alignment() {
        for mode in MODES {
            let tls = unsafe { ThreadLocalStorage::new(*mode) };
            // Normally it'd make more sense to use cheaper interior mutability
            // such as `RefCell` or `Cell`, but here we want to ensure the alignment is 1
            // to validate that we don't overlap storage.
            let my_var: ShimTlsVar<AtomicI8> = ShimTlsVar::new(&tls, || AtomicI8::new(0));
            std::thread::scope(|scope| {
                let tls = &tls;
                let my_var = &my_var;
                let threads = (0..10).map(move |i| {
                    scope.spawn(move || {
                        assert_eq!(my_var.get().load(atomic::Ordering::Relaxed), 0);
                        my_var.get().store(i, atomic::Ordering::Relaxed);
                        assert_eq!(my_var.get().load(atomic::Ordering::Relaxed), i);
                        unsafe { tls.unregister_current_thread() };
                    })
                });
                for t in threads {
                    t.join().unwrap();
                }
                unsafe { tls.unregister_current_thread() };
            });
        }
    }

    #[test_log::test]
    fn test_multithread_mutate_mixed_alignments() {
        for mode in MODES {
            let tls = unsafe { ThreadLocalStorage::new(*mode) };
            let my_i8: ShimTlsVar<AtomicI8> = ShimTlsVar::new(&tls, || AtomicI8::new(0));
            let my_i16: ShimTlsVar<AtomicI16> = ShimTlsVar::new(&tls, || AtomicI16::new(0));
            let my_i32: ShimTlsVar<AtomicI32> = ShimTlsVar::new(&tls, || AtomicI32::new(0));
            std::thread::scope(|scope| {
                let tls = &tls;
                let my_i8 = &my_i8;
                let my_i16 = &my_i16;
                let my_i32 = &my_i32;
                let threads = (0..10).map(|i| {
                    scope.spawn(move || {
                        // Access out of alignment order
                        assert_eq!(my_i8.get().load(atomic::Ordering::Relaxed), 0);
                        assert_eq!(my_i32.get().load(atomic::Ordering::Relaxed), 0);
                        assert_eq!(my_i16.get().load(atomic::Ordering::Relaxed), 0);

                        // Order shouldn't matter here, but change it from above anyway.
                        my_i32.get().store(i, atomic::Ordering::Relaxed);
                        my_i8.get().store(i as i8, atomic::Ordering::Relaxed);
                        my_i16.get().store(i as i16, atomic::Ordering::Relaxed);

                        assert_eq!(my_i16.get().load(atomic::Ordering::Relaxed), i as i16);
                        assert_eq!(my_i32.get().load(atomic::Ordering::Relaxed), i);
                        assert_eq!(my_i8.get().load(atomic::Ordering::Relaxed), i as i8);
                        unsafe { tls.unregister_current_thread() };
                    })
                });
                for t in threads {
                    t.join().unwrap();
                }
                unsafe { tls.unregister_current_thread() };
            });
        }
    }
}