bytemuck_derive/
traits.rs

1#![allow(unused_imports)]
2use std::{cmp, convert::TryFrom};
3
4use proc_macro2::{Ident, Span, TokenStream, TokenTree};
5use quote::{quote, ToTokens};
6use syn::{
7  parse::{Parse, ParseStream, Parser},
8  punctuated::Punctuated,
9  spanned::Spanned,
10  Result, *,
11};
12
/// Early-returns an `Err(Error)` from the enclosing function.
///
/// * `bail!(msg)` builds the error at `Span::call_site()`.
/// * `bail!(msg => tokens)` builds the error spanned to `tokens`, so the
///   diagnostic points at that piece of the user's code.
macro_rules! bail {
  // Message only: blamed on the call site.
  ($message:expr $(,)?) => {
    return Err(Error::new(Span::call_site(), &$message[..]))
  };

  // Message plus the tokens to blame.
  ($message:expr => $blamed:expr $(,)?) => {
    return Err(Error::new_spanned(&$blamed, $message))
  };
}
22
23pub trait Derivable {
24  fn ident(input: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path>;
25  fn implies_trait(_crate_name: &TokenStream) -> Option<TokenStream> {
26    None
27  }
28  fn asserts(
29    _input: &DeriveInput, _crate_name: &TokenStream,
30  ) -> Result<TokenStream> {
31    Ok(quote!())
32  }
33  fn check_attributes(_ty: &Data, _attributes: &[Attribute]) -> Result<()> {
34    Ok(())
35  }
36  fn trait_impl(
37    _input: &DeriveInput, _crate_name: &TokenStream,
38  ) -> Result<(TokenStream, TokenStream)> {
39    Ok((quote!(), quote!()))
40  }
41  fn requires_where_clause() -> bool {
42    true
43  }
44  fn explicit_bounds_attribute_name() -> Option<&'static str> {
45    None
46  }
47
48  /// If this trait has a custom meaning for "perfect derive", this function
49  /// should be overridden to return `Some`.
50  ///
51  /// The default is "the fields of a struct; unions and enums not supported".
52  fn perfect_derive_fields(_input: &DeriveInput) -> Option<Fields> {
53    None
54  }
55}
56
57pub struct Pod;
58
59impl Derivable for Pod {
60  fn ident(_: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path> {
61    Ok(syn::parse_quote!(#crate_name::Pod))
62  }
63
64  fn asserts(
65    input: &DeriveInput, crate_name: &TokenStream,
66  ) -> Result<TokenStream> {
67    let repr = get_repr(&input.attrs)?;
68
69    let completly_packed =
70      repr.packed == Some(1) || repr.repr == Repr::Transparent;
71
72    if !completly_packed && !input.generics.params.is_empty() {
73      bail!("\
74        Pod requires cannot be derived for non-packed types containing \
75        generic parameters because the padding requirements can't be verified \
76        for generic non-packed structs\
77      " => input.generics.params.first().unwrap());
78    }
79
80    match &input.data {
81      Data::Struct(_) => {
82        let assert_no_padding = if !completly_packed {
83          Some(generate_assert_no_padding(input, None)?)
84        } else {
85          None
86        };
87        let assert_fields_are_pod = generate_fields_are_trait(
88          input,
89          None,
90          Self::ident(input, crate_name)?,
91        )?;
92
93        Ok(quote!(
94          #assert_no_padding
95          #assert_fields_are_pod
96        ))
97      }
98      Data::Enum(_) => bail!("Deriving Pod is not supported for enums"),
99      Data::Union(_) => bail!("Deriving Pod is not supported for unions"),
100    }
101  }
102
103  fn check_attributes(_ty: &Data, attributes: &[Attribute]) -> Result<()> {
104    let repr = get_repr(attributes)?;
105    match repr.repr {
106      Repr::C => Ok(()),
107      Repr::Transparent => Ok(()),
108      _ => {
109        bail!("Pod requires the type to be #[repr(C)] or #[repr(transparent)]")
110      }
111    }
112  }
113}
114
115pub struct AnyBitPattern;
116
117impl Derivable for AnyBitPattern {
118  fn ident(_: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path> {
119    Ok(syn::parse_quote!(#crate_name::AnyBitPattern))
120  }
121
122  fn implies_trait(crate_name: &TokenStream) -> Option<TokenStream> {
123    Some(quote!(#crate_name::Zeroable))
124  }
125
126  fn asserts(
127    input: &DeriveInput, crate_name: &TokenStream,
128  ) -> Result<TokenStream> {
129    match &input.data {
130      Data::Union(_) => Ok(quote!()), // unions are always `AnyBitPattern`
131      Data::Struct(_) => {
132        generate_fields_are_trait(input, None, Self::ident(input, crate_name)?)
133      }
134      Data::Enum(_) => {
135        bail!("Deriving AnyBitPattern is not supported for enums")
136      }
137    }
138  }
139}
140
141pub struct Zeroable;
142
143/// Helper function to get the variant with discriminant zero (implicit or
144/// explicit).
145fn get_zero_variant(enum_: &DataEnum) -> Result<Option<&Variant>> {
146  let iter = VariantDiscriminantIterator::new(enum_.variants.iter());
147  let mut zero_variant = None;
148  for res in iter {
149    let (discriminant, variant) = res?;
150    if discriminant == 0 {
151      zero_variant = Some(variant);
152      break;
153    }
154  }
155  Ok(zero_variant)
156}
157
158impl Derivable for Zeroable {
159  fn ident(_: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path> {
160    Ok(syn::parse_quote!(#crate_name::Zeroable))
161  }
162
163  fn check_attributes(ty: &Data, attributes: &[Attribute]) -> Result<()> {
164    let repr = get_repr(attributes)?;
165    match ty {
166      Data::Struct(_) => Ok(()),
167      Data::Enum(_) => {
168        if !matches!(
169          repr.repr,
170          Repr::C | Repr::Integer(_) | Repr::CWithDiscriminant(_)
171        ) {
172          bail!("Zeroable requires the enum to be an explicit #[repr(Int)] and/or #[repr(C)]")
173        }
174
175        // We ensure there is a zero variant in `asserts`, since it is needed
176        // there anyway.
177
178        Ok(())
179      }
180      Data::Union(_) => Ok(()),
181    }
182  }
183
184  fn asserts(
185    input: &DeriveInput, crate_name: &TokenStream,
186  ) -> Result<TokenStream> {
187    match &input.data {
188      Data::Union(_) => Ok(quote!()), // unions are always `Zeroable`
189      Data::Struct(_) => {
190        generate_fields_are_trait(input, None, Self::ident(input, crate_name)?)
191      }
192      Data::Enum(enum_) => {
193        let zero_variant = get_zero_variant(enum_)?;
194
195        if zero_variant.is_none() {
196          bail!("No variant's discriminant is 0")
197        };
198
199        generate_fields_are_trait(
200          input,
201          zero_variant,
202          Self::ident(input, crate_name)?,
203        )
204      }
205    }
206  }
207
208  fn explicit_bounds_attribute_name() -> Option<&'static str> {
209    Some("zeroable")
210  }
211
212  fn perfect_derive_fields(input: &DeriveInput) -> Option<Fields> {
213    match &input.data {
214      Data::Struct(struct_) => Some(struct_.fields.clone()),
215      Data::Enum(enum_) => {
216        // We handle `Err` returns from `get_zero_variant` in `asserts`, so it's
217        // fine to just ignore them here and return `None`.
218        // Otherwise, we clone the `fields` of the zero variant (if any).
219        Some(get_zero_variant(enum_).ok()??.fields.clone())
220      }
221      Data::Union(_) => None,
222    }
223  }
224}
225
226pub struct NoUninit;
227
228impl Derivable for NoUninit {
229  fn ident(_: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path> {
230    Ok(syn::parse_quote!(#crate_name::NoUninit))
231  }
232
233  fn check_attributes(ty: &Data, attributes: &[Attribute]) -> Result<()> {
234    let repr = get_repr(attributes)?;
235    match ty {
236      Data::Struct(_) => match repr.repr {
237        Repr::C | Repr::Transparent => Ok(()),
238        _ => bail!("NoUninit requires the struct to be #[repr(C)] or #[repr(transparent)]"),
239      },
240      Data::Enum(DataEnum { variants,.. }) => {
241        if !enum_has_fields(variants.iter()) {
242          if matches!(repr.repr, Repr::C | Repr::Integer(_)) {
243            Ok(())
244          } else {
245            bail!("NoUninit requires the enum to be #[repr(C)] or #[repr(Int)]")
246          }
247        } else if matches!(repr.repr, Repr::Rust) {
248          bail!("NoUninit requires an explicit repr annotation because `repr(Rust)` doesn't have a specified type layout")
249        } else {
250          Ok(())
251        }
252      },
253      Data::Union(_) => bail!("NoUninit can only be derived on enums and structs")
254    }
255  }
256
257  fn asserts(
258    input: &DeriveInput, crate_name: &TokenStream,
259  ) -> Result<TokenStream> {
260    if !input.generics.params.is_empty() {
261      bail!("NoUninit cannot be derived for structs containing generic parameters because the padding requirements can't be verified for generic structs");
262    }
263
264    match &input.data {
265      Data::Struct(DataStruct { .. }) => {
266        let assert_no_padding = generate_assert_no_padding(&input, None)?;
267        let assert_fields_are_no_padding = generate_fields_are_trait(
268          &input,
269          None,
270          Self::ident(input, crate_name)?,
271        )?;
272
273        Ok(quote!(
274            #assert_no_padding
275            #assert_fields_are_no_padding
276        ))
277      }
278      Data::Enum(DataEnum { variants, .. }) => {
279        if enum_has_fields(variants.iter()) {
280          // There are two different C representations for enums with fields:
281          // There's `#[repr(C)]`/`[repr(C, int)]` and `#[repr(int)]`.
282          // `#[repr(C)]` is equivalent to a struct containing the discriminant
283          // and a union of structs representing each variant's fields.
284          // `#[repr(C)]` is equivalent to a union containing structs of the
285          // discriminant and the fields.
286          //
287          // See https://doc.rust-lang.org/reference/type-layout.html#r-layout.repr.c.adt
288          // and https://doc.rust-lang.org/reference/type-layout.html#r-layout.repr.primitive.adt
289          //
290          // In practice the only difference between the two is whether and
291          // where padding bytes are placed. For `#[repr(C)]` enums, the first
292          // enum fields of all variants start at the same location (the first
293          // byte in the union). For `#[repr(int)]` enums, the structs
294          // representing each variant are layed out individually and padding
295          // does not depend on other variants, but only on the size of the
296          // discriminant and the alignment of the first field. The location of
297          // the first field might differ between variants, potentially
298          // resulting in less padding or padding placed later in the enum.
299          //
300          // The `NoUninit` derive macro asserts that no padding exists by
301          // removing all padding with `#[repr(packed)]` and checking that this
302          // doesn't change the size. Since the location and presence of
303          // padding bytes is the only difference between the two
304          // representations and we're removing all padding bytes, the resuling
305          // layout would identical for both representations. This means that
306          // we can just pick one of the representations and don't have to
307          // implement desugaring for both. We chose to implement the
308          // desugaring for `#[repr(int)]`.
309
310          let enum_discriminant = generate_enum_discriminant(input)?;
311          let variant_assertions = variants
312            .iter()
313            .map(|variant| {
314              let assert_no_padding =
315                generate_assert_no_padding(&input, Some(variant))?;
316              let assert_fields_are_no_padding = generate_fields_are_trait(
317                &input,
318                Some(variant),
319                Self::ident(input, crate_name)?,
320              )?;
321
322              Ok(quote!(
323                  #assert_no_padding
324                  #assert_fields_are_no_padding
325              ))
326            })
327            .collect::<Result<Vec<_>>>()?;
328          Ok(quote! {
329            const _: () = {
330              #enum_discriminant
331              #(#variant_assertions)*
332            };
333          })
334        } else {
335          Ok(quote!())
336        }
337      }
338      Data::Union(_) => bail!("NoUninit cannot be derived for unions"), /* shouldn't be possible since we already error in attribute check for this case */
339    }
340  }
341
342  fn trait_impl(
343    _input: &DeriveInput, _crate_name: &TokenStream,
344  ) -> Result<(TokenStream, TokenStream)> {
345    Ok((quote!(), quote!()))
346  }
347}
348
349pub struct CheckedBitPattern;
350
351impl Derivable for CheckedBitPattern {
352  fn ident(_: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path> {
353    Ok(syn::parse_quote!(#crate_name::CheckedBitPattern))
354  }
355
356  fn check_attributes(ty: &Data, attributes: &[Attribute]) -> Result<()> {
357    let repr = get_repr(attributes)?;
358    match ty {
359      Data::Struct(_) => match repr.repr {
360        Repr::C | Repr::Transparent => Ok(()),
361        _ => bail!("CheckedBitPattern derive requires the struct to be #[repr(C)] or #[repr(transparent)]"),
362      },
363      Data::Enum(DataEnum { variants,.. }) => {
364        if !enum_has_fields(variants.iter()){
365          if matches!(repr.repr, Repr::C | Repr::Integer(_)) {
366            Ok(())
367          } else {
368            bail!("CheckedBitPattern requires the enum to be #[repr(C)] or #[repr(Int)]")
369          }
370        } else if matches!(repr.repr, Repr::Rust) {
371          bail!("CheckedBitPattern requires an explicit repr annotation because `repr(Rust)` doesn't have a specified type layout")
372        } else {
373          Ok(())
374        }
375      }
376      Data::Union(_) => bail!("CheckedBitPattern can only be derived on enums and structs")
377    }
378  }
379
380  fn asserts(
381    input: &DeriveInput, crate_name: &TokenStream,
382  ) -> Result<TokenStream> {
383    if !input.generics.params.is_empty() {
384      bail!("CheckedBitPattern cannot be derived for structs containing generic parameters");
385    }
386
387    match &input.data {
388      Data::Struct(DataStruct { .. }) => {
389        let assert_fields_are_maybe_pod = generate_fields_are_trait(
390          &input,
391          None,
392          Self::ident(input, crate_name)?,
393        )?;
394
395        Ok(assert_fields_are_maybe_pod)
396      }
397      // nothing needed, already guaranteed OK by NoUninit.
398      Data::Enum(_) => Ok(quote!()),
399      Data::Union(_) => bail!("Internal error in CheckedBitPattern derive"), /* shouldn't be possible since we already error in attribute check for this case */
400    }
401  }
402
403  fn trait_impl(
404    input: &DeriveInput, crate_name: &TokenStream,
405  ) -> Result<(TokenStream, TokenStream)> {
406    match &input.data {
407      Data::Struct(DataStruct { fields, .. }) => {
408        generate_checked_bit_pattern_struct(
409          &input.ident,
410          fields,
411          &input.attrs,
412          crate_name,
413        )
414      }
415      Data::Enum(DataEnum { variants, .. }) => {
416        generate_checked_bit_pattern_enum(input, variants, crate_name)
417      }
418      Data::Union(_) => bail!("Internal error in CheckedBitPattern derive"), /* shouldn't be possible since we already error in attribute check for this case */
419    }
420  }
421}
422
423pub struct TransparentWrapper;
424
425impl TransparentWrapper {
426  fn get_wrapper_type(
427    attributes: &[Attribute], fields: &Fields,
428  ) -> Option<TokenStream> {
429    let transparent_param = get_simple_attr(attributes, "transparent");
430    transparent_param.map(|ident| ident.to_token_stream()).or_else(|| {
431      let mut types = get_field_types(&fields);
432      let first_type = types.next();
433      if let Some(_) = types.next() {
434        // can't guess param type if there is more than one field
435        return None;
436      } else {
437        first_type.map(|ty| ty.to_token_stream())
438      }
439    })
440  }
441}
442
443impl Derivable for TransparentWrapper {
444  fn ident(input: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path> {
445    let fields = get_struct_fields(input)?;
446
447    let ty = match Self::get_wrapper_type(&input.attrs, &fields) {
448      Some(ty) => ty,
449      None => bail!(
450        "\
451        when deriving TransparentWrapper for a struct with more than one field \
452        you need to specify the transparent field using #[transparent(T)]\
453      "
454      ),
455    };
456
457    Ok(syn::parse_quote!(#crate_name::TransparentWrapper<#ty>))
458  }
459
460  fn asserts(
461    input: &DeriveInput, crate_name: &TokenStream,
462  ) -> Result<TokenStream> {
463    let (impl_generics, _ty_generics, where_clause) =
464      input.generics.split_for_impl();
465    let fields = get_struct_fields(input)?;
466    let wrapped_type = match Self::get_wrapper_type(&input.attrs, &fields) {
467      Some(wrapped_type) => wrapped_type.to_string(),
468      None => unreachable!(), /* other code will already reject this derive */
469    };
470    let mut wrapped_field_ty = None;
471    let mut nonwrapped_field_tys = vec![];
472    for field in fields.iter() {
473      let field_ty = &field.ty;
474      if field_ty.to_token_stream().to_string() == wrapped_type {
475        if wrapped_field_ty.is_some() {
476          bail!(
477            "TransparentWrapper can only have one field of the wrapped type"
478          );
479        }
480        wrapped_field_ty = Some(field_ty);
481      } else {
482        nonwrapped_field_tys.push(field_ty);
483      }
484    }
485    if let Some(wrapped_field_ty) = wrapped_field_ty {
486      Ok(quote!(
487        const _: () = {
488          #[repr(transparent)]
489          #[allow(clippy::multiple_bound_locations)]
490          struct AssertWrappedIsWrapped #impl_generics((u8, ::core::marker::PhantomData<#wrapped_field_ty>), #(#nonwrapped_field_tys),*) #where_clause;
491          fn assert_zeroable<Z: #crate_name::Zeroable>() {}
492          #[allow(clippy::multiple_bound_locations)]
493          fn check #impl_generics () #where_clause {
494            #(
495              assert_zeroable::<#nonwrapped_field_tys>();
496            )*
497          }
498        };
499      ))
500    } else {
501      bail!("TransparentWrapper must have one field of the wrapped type")
502    }
503  }
504
505  fn check_attributes(_ty: &Data, attributes: &[Attribute]) -> Result<()> {
506    let repr = get_repr(attributes)?;
507
508    match repr.repr {
509      Repr::Transparent => Ok(()),
510      _ => {
511        bail!(
512          "TransparentWrapper requires the struct to be #[repr(transparent)]"
513        )
514      }
515    }
516  }
517
518  fn requires_where_clause() -> bool {
519    false
520  }
521}
522
523pub struct Contiguous;
524
525impl Derivable for Contiguous {
526  fn ident(_: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path> {
527    Ok(syn::parse_quote!(#crate_name::Contiguous))
528  }
529
530  fn trait_impl(
531    input: &DeriveInput, _crate_name: &TokenStream,
532  ) -> Result<(TokenStream, TokenStream)> {
533    let repr = get_repr(&input.attrs)?;
534
535    let integer_ty = if let Some(integer_ty) = repr.repr.as_integer() {
536      integer_ty
537    } else {
538      bail!("Contiguous requires the enum to be #[repr(Int)]");
539    };
540
541    let variants = get_enum_variants(input)?;
542    if enum_has_fields(variants.clone()) {
543      return Err(Error::new_spanned(
544        &input,
545        "Only fieldless enums are supported",
546      ));
547    }
548
549    let mut variants_with_discriminant =
550      VariantDiscriminantIterator::new(variants);
551
552    let (min, max, count) = variants_with_discriminant.try_fold(
553      (i128::MAX, i128::MIN, 0),
554      |(min, max, count), res| {
555        let (discriminant, _variant) = res?;
556        Ok::<_, Error>((
557          i128::min(min, discriminant),
558          i128::max(max, discriminant),
559          count + 1,
560        ))
561      },
562    )?;
563
564    if max - min != count - 1 {
565      bail! {
566        "Contiguous requires the enum discriminants to be contiguous",
567      }
568    }
569
570    let min_lit = LitInt::new(&format!("{}", min), input.span());
571    let max_lit = LitInt::new(&format!("{}", max), input.span());
572
573    // `from_integer` and `into_integer` are usually provided by the trait's
574    // default implementation. We override this implementation because it
575    // goes through `transmute_copy`, which can lead to inefficient assembly as seen in https://github.com/Lokathor/bytemuck/issues/175 .
576
577    Ok((
578      quote!(),
579      quote! {
580          type Int = #integer_ty;
581
582          #[allow(clippy::missing_docs_in_private_items)]
583          const MIN_VALUE: #integer_ty = #min_lit;
584
585          #[allow(clippy::missing_docs_in_private_items)]
586          const MAX_VALUE: #integer_ty = #max_lit;
587
588          #[inline]
589          fn from_integer(value: Self::Int) -> Option<Self> {
590            #[allow(clippy::manual_range_contains)]
591            if Self::MIN_VALUE <= value && value <= Self::MAX_VALUE {
592              Some(unsafe { ::core::mem::transmute(value) })
593            } else {
594              None
595            }
596          }
597
598          #[inline]
599          fn into_integer(self) -> Self::Int {
600              self as #integer_ty
601          }
602      },
603    ))
604  }
605}
606
607fn get_struct_fields(input: &DeriveInput) -> Result<&Fields> {
608  if let Data::Struct(DataStruct { fields, .. }) = &input.data {
609    Ok(fields)
610  } else {
611    bail!("deriving this trait is only supported for structs")
612  }
613}
614
615/// Extract the `Fields` off a `DeriveInput`, or, in the `enum` case, off
616/// those of the `enum_variant`, when provided (e.g., for `Zeroable`).
617///
618/// We purposely allow not providing an `enum_variant` for cases where
619/// the caller wants to reject supporting `enum`s (e.g., `NoPadding`).
620fn get_fields(
621  input: &DeriveInput, enum_variant: Option<&Variant>,
622) -> Result<Fields> {
623  match &input.data {
624    Data::Struct(DataStruct { fields, .. }) => Ok(fields.clone()),
625    Data::Union(DataUnion { fields, .. }) => Ok(Fields::Named(fields.clone())),
626    Data::Enum(_) => match enum_variant {
627      Some(variant) => Ok(variant.fields.clone()),
628      None => bail!("deriving this trait is not supported for enums"),
629    },
630  }
631}
632
633fn get_enum_variants<'a>(
634  input: &'a DeriveInput,
635) -> Result<impl Iterator<Item = &'a Variant> + Clone + 'a> {
636  if let Data::Enum(DataEnum { variants, .. }) = &input.data {
637    Ok(variants.iter())
638  } else {
639    bail!("deriving this trait is only supported for enums")
640  }
641}
642
643fn get_field_types<'a>(
644  fields: &'a Fields,
645) -> impl Iterator<Item = &'a Type> + 'a {
646  fields.iter().map(|field| &field.ty)
647}
648
649fn generate_checked_bit_pattern_struct(
650  input_ident: &Ident, fields: &Fields, attrs: &[Attribute],
651  crate_name: &TokenStream,
652) -> Result<(TokenStream, TokenStream)> {
653  let bits_ty = Ident::new(&format!("{}Bits", input_ident), input_ident.span());
654
655  let repr = get_repr(attrs)?;
656
657  let field_names = fields
658    .iter()
659    .enumerate()
660    .map(|(i, field)| {
661      field.ident.clone().unwrap_or_else(|| {
662        Ident::new(&format!("field{}", i), input_ident.span())
663      })
664    })
665    .collect::<Vec<_>>();
666  let field_tys = fields.iter().map(|field| &field.ty).collect::<Vec<_>>();
667
668  let field_name = &field_names[..];
669  let field_ty = &field_tys[..];
670
671  Ok((
672    quote! {
673        #[doc = #GENERATED_TYPE_DOCUMENTATION]
674        #repr
675        #[derive(Clone, Copy, #crate_name::AnyBitPattern)]
676        #[allow(missing_docs)]
677        pub struct #bits_ty {
678            #(#field_name: <#field_ty as #crate_name::CheckedBitPattern>::Bits,)*
679        }
680
681        #[allow(unexpected_cfgs)]
682        const _: () = {
683          #[cfg(not(target_arch = "spirv"))]
684          impl ::core::fmt::Debug for #bits_ty {
685            fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
686              let mut debug_struct = ::core::fmt::Formatter::debug_struct(f, ::core::stringify!(#bits_ty));
687              #(::core::fmt::DebugStruct::field(&mut debug_struct, ::core::stringify!(#field_name), &{ self.#field_name });)*
688              ::core::fmt::DebugStruct::finish(&mut debug_struct)
689            }
690          }
691        };
692    },
693    quote! {
694        type Bits = #bits_ty;
695
696        #[inline]
697        #[allow(clippy::double_comparisons, unused)]
698        fn is_valid_bit_pattern(bits: &#bits_ty) -> bool {
699            #(<#field_ty as #crate_name::CheckedBitPattern>::is_valid_bit_pattern(&{ bits.#field_name }) && )* true
700        }
701    },
702  ))
703}
704
705fn generate_checked_bit_pattern_enum(
706  input: &DeriveInput, variants: &Punctuated<Variant, Token![,]>,
707  crate_name: &TokenStream,
708) -> Result<(TokenStream, TokenStream)> {
709  if enum_has_fields(variants.iter()) {
710    generate_checked_bit_pattern_enum_with_fields(input, variants, crate_name)
711  } else {
712    generate_checked_bit_pattern_enum_without_fields(
713      input, variants, crate_name,
714    )
715  }
716}
717
718fn generate_checked_bit_pattern_enum_without_fields(
719  input: &DeriveInput, variants: &Punctuated<Variant, Token![,]>,
720  crate_name: &TokenStream,
721) -> Result<(TokenStream, TokenStream)> {
722  let span = input.span();
723  let mut variants_with_discriminant =
724    VariantDiscriminantIterator::new(variants.iter());
725
726  let (min, max, count) = variants_with_discriminant.try_fold(
727    (i128::MAX, i128::MIN, 0),
728    |(min, max, count), res| {
729      let (discriminant, _variant) = res?;
730      Ok::<_, Error>((
731        i128::min(min, discriminant),
732        i128::max(max, discriminant),
733        count + 1,
734      ))
735    },
736  )?;
737
738  let check = if count == 0 {
739    quote!(false)
740  } else if max - min == count - 1 {
741    // contiguous range
742    let min_lit = LitInt::new(&format!("{}", min), span);
743    let max_lit = LitInt::new(&format!("{}", max), span);
744
745    quote!(*bits >= #min_lit && *bits <= #max_lit)
746  } else {
747    // not contiguous range, check for each
748    let variant_discriminant_lits =
749      VariantDiscriminantIterator::new(variants.iter())
750        .map(|res| {
751          let (discriminant, _variant) = res?;
752          Ok(LitInt::new(&format!("{}", discriminant), span))
753        })
754        .collect::<Result<Vec<_>>>()?;
755
756    // count is at least 1
757    let first = &variant_discriminant_lits[0];
758    let rest = &variant_discriminant_lits[1..];
759
760    quote!(matches!(*bits, #first #(| #rest )*))
761  };
762
763  let (integer, defs) = get_enum_discriminant(input, crate_name)?;
764  Ok((
765    quote!(#defs),
766    quote! {
767        type Bits = #integer;
768
769        #[inline]
770        #[allow(clippy::double_comparisons)]
771        fn is_valid_bit_pattern(bits: &Self::Bits) -> bool {
772            #check
773        }
774    },
775  ))
776}
777
778fn generate_checked_bit_pattern_enum_with_fields(
779  input: &DeriveInput, variants: &Punctuated<Variant, Token![,]>,
780  crate_name: &TokenStream,
781) -> Result<(TokenStream, TokenStream)> {
782  let representation = get_repr(&input.attrs)?;
783  let vis = &input.vis;
784
785  match representation.repr {
786    Repr::Rust => unreachable!(),
787    Repr::C | Repr::CWithDiscriminant(_) => {
788      let (integer, defs) = get_enum_discriminant(input, crate_name)?;
789      let input_ident = &input.ident;
790
791      let bits_repr = Representation { repr: Repr::C, ..representation };
792
793      // the enum manually re-configured as the actual tagged union it
794      // represents, thus circumventing the requirements rust imposes on
795      // the tag even when using #[repr(C)] enum layout
796      // see: https://doc.rust-lang.org/reference/type-layout.html#reprc-enums-with-fields
797      let bits_ty_ident =
798        Ident::new(&format!("{input_ident}Bits"), input.span());
799
800      // the variants union part of the tagged union. These get put into a union
801      // which gets the AnyBitPattern derive applied to it, thus checking
802      // that the fields of the union obey the requriements of AnyBitPattern.
803      // The types that actually go in the union are one more level of
804      // indirection deep: we generate new structs for each variant
805      // (`variant_struct_definitions`) which themselves have the
806      // `CheckedBitPattern` derive applied, thus generating
807      // `{variant_struct_ident}Bits` structs, which are the ones that go
808      // into this union.
809      let variants_union_ident =
810        Ident::new(&format!("{}Variants", input.ident), input.span());
811
812      let variant_struct_idents = variants.iter().map(|v| {
813        Ident::new(&format!("{input_ident}Variant{}", v.ident), v.span())
814      });
815
816      let variant_struct_definitions =
817        variant_struct_idents.clone().zip(variants.iter()).map(|(variant_struct_ident, v)| {
818          let fields = v.fields.iter().map(|v| &v.ty);
819
820          quote! {
821            #[derive(::core::clone::Clone, ::core::marker::Copy, #crate_name::CheckedBitPattern)]
822            #[repr(C)]
823            #vis struct #variant_struct_ident(#(#fields),*);
824          }
825        });
826
827      let union_fields = variant_struct_idents
828        .clone()
829        .zip(variants.iter())
830        .map(|(variant_struct_ident, v)| {
831          let variant_struct_bits_ident =
832            Ident::new(&format!("{variant_struct_ident}Bits"), input.span());
833          let field_ident = &v.ident;
834          quote! {
835            #field_ident: #variant_struct_bits_ident
836          }
837        });
838
839      let variant_checks = variant_struct_idents
840        .clone()
841        .zip(VariantDiscriminantIterator::new(variants.iter()))
842        .zip(variants.iter())
843        .map(|((variant_struct_ident, discriminant), v)| -> Result<_> {
844          let (discriminant, _variant) = discriminant?;
845          let discriminant = LitInt::new(&discriminant.to_string(), v.span());
846          let ident = &v.ident;
847          Ok(quote! {
848            #discriminant => {
849              let payload = unsafe { &bits.payload.#ident };
850              <#variant_struct_ident as #crate_name::CheckedBitPattern>::is_valid_bit_pattern(payload)
851            }
852          })
853        })
854        .collect::<Result<Vec<_>>>()?;
855
856      Ok((
857        quote! {
858          #defs
859
860          #[doc = #GENERATED_TYPE_DOCUMENTATION]
861          #[derive(::core::clone::Clone, ::core::marker::Copy, #crate_name::AnyBitPattern)]
862          #bits_repr
863          #vis struct #bits_ty_ident {
864            tag: #integer,
865            payload: #variants_union_ident,
866          }
867
868          #[allow(unexpected_cfgs)]
869          const _: () = {
870            #[cfg(not(target_arch = "spirv"))]
871            impl ::core::fmt::Debug for #bits_ty_ident {
872              fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
873                let mut debug_struct = ::core::fmt::Formatter::debug_struct(f, ::core::stringify!(#bits_ty_ident));
874                ::core::fmt::DebugStruct::field(&mut debug_struct, "tag", &self.tag);
875                ::core::fmt::DebugStruct::field(&mut debug_struct, "payload", &self.payload);
876                ::core::fmt::DebugStruct::finish(&mut debug_struct)
877              }
878            }
879          };
880
881          #[derive(::core::clone::Clone, ::core::marker::Copy, #crate_name::AnyBitPattern)]
882          #[repr(C)]
883          #[allow(non_snake_case)]
884          #vis union #variants_union_ident {
885            #(#union_fields,)*
886          }
887
888          #[allow(unexpected_cfgs)]
889          const _: () = {
890            #[cfg(not(target_arch = "spirv"))]
891            impl ::core::fmt::Debug for #variants_union_ident {
892              fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
893                let mut debug_struct = ::core::fmt::Formatter::debug_struct(f, ::core::stringify!(#variants_union_ident));
894                ::core::fmt::DebugStruct::finish_non_exhaustive(&mut debug_struct)
895              }
896            }
897          };
898
899          #(#variant_struct_definitions)*
900        },
901        quote! {
902          type Bits = #bits_ty_ident;
903
904          #[inline]
905          #[allow(clippy::double_comparisons)]
906          fn is_valid_bit_pattern(bits: &Self::Bits) -> bool {
907            match bits.tag {
908              #(#variant_checks)*
909              _ => false,
910            }
911          }
912        },
913      ))
914    }
915    Repr::Transparent => {
916      if variants.len() != 1 {
917        bail!("enums with more than one variant cannot be transparent")
918      }
919
920      let variant = &variants[0];
921
922      let bits_ty = Ident::new(&format!("{}Bits", input.ident), input.span());
923      let fields = variant.fields.iter().map(|v| &v.ty);
924
925      Ok((
926        quote! {
927          #[doc = #GENERATED_TYPE_DOCUMENTATION]
928          #[derive(::core::clone::Clone, ::core::marker::Copy, #crate_name::CheckedBitPattern)]
929          #[repr(C)]
930          #vis struct #bits_ty(#(#fields),*);
931        },
932        quote! {
933          type Bits = <#bits_ty as #crate_name::CheckedBitPattern>::Bits;
934
935          #[inline]
936          #[allow(clippy::double_comparisons)]
937          fn is_valid_bit_pattern(bits: &Self::Bits) -> bool {
938            <#bits_ty as #crate_name::CheckedBitPattern>::is_valid_bit_pattern(bits)
939          }
940        },
941      ))
942    }
943    Repr::Integer(integer) => {
944      let bits_repr = Representation { repr: Repr::C, ..representation };
945      let input_ident = &input.ident;
946
947      // the enum manually re-configured as the union it represents. such a
948      // union is the union of variants as a repr(c) struct with the
949      // discriminator type inserted at the beginning. in our case we
950      // union the `Bits` representation of each variant rather than the variant
951      // itself, which we generate via a nested `CheckedBitPattern` derive
952      // on the `variant_struct_definitions` generated below.
953      //
954      // see: https://doc.rust-lang.org/reference/type-layout.html#primitive-representation-of-enums-with-fields
955      let bits_ty_ident =
956        Ident::new(&format!("{input_ident}Bits"), input.span());
957
958      let variant_struct_idents = variants.iter().map(|v| {
959        Ident::new(&format!("{input_ident}Variant{}", v.ident), v.span())
960      });
961
962      let variant_struct_definitions =
963        variant_struct_idents.clone().zip(variants.iter()).map(|(variant_struct_ident, v)| {
964          let fields = v.fields.iter().map(|v| &v.ty);
965
966          // adding the discriminant repr integer as first field, as described above
967          quote! {
968            #[derive(::core::clone::Clone, ::core::marker::Copy, #crate_name::CheckedBitPattern)]
969            #[repr(C)]
970            #vis struct #variant_struct_ident(#integer, #(#fields),*);
971          }
972        });
973
974      let union_fields = variant_struct_idents
975        .clone()
976        .zip(variants.iter())
977        .map(|(variant_struct_ident, v)| {
978          let variant_struct_bits_ident =
979            Ident::new(&format!("{variant_struct_ident}Bits"), input.span());
980          let field_ident = &v.ident;
981          quote! {
982            #field_ident: #variant_struct_bits_ident
983          }
984        });
985
986      let variant_checks = variant_struct_idents
987        .clone()
988        .zip(VariantDiscriminantIterator::new(variants.iter()))
989        .zip(variants.iter())
990        .map(|((variant_struct_ident, discriminant), v)| -> Result<_> {
991          let (discriminant, _variant) = discriminant?;
992          let discriminant = LitInt::new(&discriminant.to_string(), v.span());
993          let ident = &v.ident;
994          Ok(quote! {
995            #discriminant => {
996              let payload = unsafe { &bits.#ident };
997              <#variant_struct_ident as #crate_name::CheckedBitPattern>::is_valid_bit_pattern(payload)
998            }
999          })
1000        })
1001        .collect::<Result<Vec<_>>>()?;
1002
1003      Ok((
1004        quote! {
1005          #[doc = #GENERATED_TYPE_DOCUMENTATION]
1006          #[derive(::core::clone::Clone, ::core::marker::Copy, #crate_name::AnyBitPattern)]
1007          #bits_repr
1008          #[allow(non_snake_case)]
1009          #vis union #bits_ty_ident {
1010            __tag: #integer,
1011            #(#union_fields,)*
1012          }
1013
1014          #[allow(unexpected_cfgs)]
1015          const _: () = {
1016            #[cfg(not(target_arch = "spirv"))]
1017            impl ::core::fmt::Debug for #bits_ty_ident {
1018              fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
1019                let mut debug_struct = ::core::fmt::Formatter::debug_struct(f, ::core::stringify!(#bits_ty_ident));
1020                ::core::fmt::DebugStruct::field(&mut debug_struct, "tag", unsafe { &self.__tag });
1021                ::core::fmt::DebugStruct::finish_non_exhaustive(&mut debug_struct)
1022              }
1023            }
1024          };
1025
1026          #(#variant_struct_definitions)*
1027        },
1028        quote! {
1029          type Bits = #bits_ty_ident;
1030
1031          #[inline]
1032          #[allow(clippy::double_comparisons)]
1033          fn is_valid_bit_pattern(bits: &Self::Bits) -> bool {
1034            match unsafe { bits.__tag } {
1035              #(#variant_checks)*
1036              _ => false,
1037            }
1038          }
1039        },
1040      ))
1041    }
1042  }
1043}
1044
1045/// Check that a struct or enum has no padding by asserting that the size of
1046/// the type is equal to the sum of the size of it's fields and discriminant
1047/// (for enums, this must be asserted for each variant).
1048fn generate_assert_no_padding(
1049  input: &DeriveInput, enum_variant: Option<&Variant>,
1050) -> Result<TokenStream> {
1051  let struct_type = &input.ident;
1052  let fields = get_fields(input, enum_variant)?;
1053
1054  // If the type is an enum, determine the type of its discriminant.
1055  let enum_discriminant = if matches!(input.data, Data::Enum(_)) {
1056    let ident =
1057      Ident::new(&format!("{}Discriminant", input.ident), input.ident.span());
1058    Some(ident.into_token_stream())
1059  } else {
1060    None
1061  };
1062
1063  // Prepend the type of the discriminant to the types of the fields.
1064  let mut field_types = enum_discriminant
1065    .into_iter()
1066    .chain(get_field_types(&fields).map(ToTokens::to_token_stream));
1067  let size_sum = if let Some(first) = field_types.next() {
1068    let size_first = quote!(::core::mem::size_of::<#first>());
1069    let size_rest = quote!(#( + ::core::mem::size_of::<#field_types>() )*);
1070
1071    quote!(#size_first #size_rest)
1072  } else {
1073    quote!(0)
1074  };
1075
1076  Ok(quote! {const _: fn() = || {
1077    #[doc(hidden)]
1078    struct TypeWithoutPadding([u8; #size_sum]);
1079    let _ = ::core::mem::transmute::<#struct_type, TypeWithoutPadding>;
1080  };})
1081}
1082
1083/// Check that all fields implement a given trait
1084fn generate_fields_are_trait(
1085  input: &DeriveInput, enum_variant: Option<&Variant>, trait_: syn::Path,
1086) -> Result<TokenStream> {
1087  let (impl_generics, _ty_generics, where_clause) =
1088    input.generics.split_for_impl();
1089  let fields = get_fields(input, enum_variant)?;
1090  let field_types = get_field_types(&fields);
1091  Ok(quote! {#(const _: fn() = || {
1092      #[allow(clippy::missing_const_for_fn)]
1093      #[doc(hidden)]
1094      fn check #impl_generics () #where_clause {
1095        fn assert_impl<T: #trait_>() {}
1096        assert_impl::<#field_types>();
1097      }
1098    };)*
1099  })
1100}
1101
1102/// Get the type of an enum's discriminant.
1103///
1104/// For `repr(int)` and `repr(C, int)` enums, this will return the known bare
1105/// integer type specified.
1106///
1107/// For `repr(C)` enums, this will extract the underlying size chosen by rustc.
1108/// It will return a token stream which is a type expression that evaluates to
1109/// a primitive integer type of this size, using our `EnumTagIntegerBytes`
1110/// trait.
1111///
1112/// For fieldless `repr(C)` enums, we can feed the size of the enum directly
1113/// into the trait.
1114///
1115/// For `repr(C)` enums with fields, we generate a new fieldless `repr(C)` enum
1116/// with the same variants, then use that in the calculation. This is the
1117/// specified behavior, see https://doc.rust-lang.org/stable/reference/type-layout.html#reprc-enums-with-fields
1118///
1119/// Returns a tuple of (type ident, auxiliary definitions)
1120fn get_enum_discriminant(
1121  input: &DeriveInput, crate_name: &TokenStream,
1122) -> Result<(TokenStream, TokenStream)> {
1123  let repr = get_repr(&input.attrs)?;
1124  match repr.repr {
1125    Repr::C => {
1126      let e = if let Data::Enum(e) = &input.data { e } else { unreachable!() };
1127      if enum_has_fields(e.variants.iter()) {
1128        // If the enum has fields, we must first isolate the discriminant by
1129        // removing all the fields.
1130        let enum_discriminant = generate_enum_discriminant(input)?;
1131        let discriminant_ident = Ident::new(
1132          &format!("{}Discriminant", input.ident),
1133          input.ident.span(),
1134        );
1135        Ok((
1136          quote!(<[::core::primitive::u8; ::core::mem::size_of::<#discriminant_ident>()] as #crate_name::derive::EnumTagIntegerBytes>::Integer),
1137          quote! {
1138            #enum_discriminant
1139          },
1140        ))
1141      } else {
1142        // If the enum doesn't have fields, we can just use it directly.
1143        let ident = &input.ident;
1144        Ok((
1145          quote!(<[::core::primitive::u8; ::core::mem::size_of::<#ident>()] as #crate_name::derive::EnumTagIntegerBytes>::Integer),
1146          quote!(),
1147        ))
1148      }
1149    }
1150    Repr::Integer(integer) | Repr::CWithDiscriminant(integer) => {
1151      Ok((quote!(#integer), quote!()))
1152    }
1153    _ => unreachable!(),
1154  }
1155}
1156
1157fn generate_enum_discriminant(input: &DeriveInput) -> Result<TokenStream> {
1158  let e = if let Data::Enum(e) = &input.data { e } else { unreachable!() };
1159  let repr = get_repr(&input.attrs)?;
1160  let repr = match repr.repr {
1161    Repr::C => quote!(#[repr(C)]),
1162    Repr::Integer(int) | Repr::CWithDiscriminant(int) => quote!(#[repr(#int)]),
1163    Repr::Rust | Repr::Transparent => unreachable!(),
1164  };
1165  let ident =
1166    Ident::new(&format!("{}Discriminant", input.ident), input.ident.span());
1167  let variants = e.variants.iter().cloned().map(|mut e| {
1168    e.fields = Fields::Unit;
1169    e
1170  });
1171  Ok(quote! {
1172    #repr
1173    #[allow(dead_code)]
1174    enum #ident {
1175      #(#variants,)*
1176    }
1177  })
1178}
1179
1180fn get_ident_from_stream(tokens: TokenStream) -> Option<Ident> {
1181  match tokens.into_iter().next() {
1182    Some(TokenTree::Group(group)) => get_ident_from_stream(group.stream()),
1183    Some(TokenTree::Ident(ident)) => Some(ident),
1184    _ => None,
1185  }
1186}
1187
1188/// get a simple #[foo(bar)] attribute, returning "bar"
1189fn get_simple_attr(attributes: &[Attribute], attr_name: &str) -> Option<Ident> {
1190  for attr in attributes {
1191    if let (AttrStyle::Outer, Meta::List(list)) = (&attr.style, &attr.meta) {
1192      if list.path.is_ident(attr_name) {
1193        if let Some(ident) = get_ident_from_stream(list.tokens.clone()) {
1194          return Some(ident);
1195        }
1196      }
1197    }
1198  }
1199
1200  None
1201}
1202
1203fn get_repr(attributes: &[Attribute]) -> Result<Representation> {
1204  attributes
1205    .iter()
1206    .filter_map(|attr| {
1207      if attr.path().is_ident("repr") {
1208        Some(attr.parse_args::<Representation>())
1209      } else {
1210        None
1211      }
1212    })
1213    .try_fold(Representation::default(), |a, b| {
1214      let b = b?;
1215      Ok(Representation {
1216        repr: match (a.repr, b.repr) {
1217          (a, Repr::Rust) => a,
1218          (Repr::Rust, b) => b,
1219          _ => bail!("conflicting representation hints"),
1220        },
1221        packed: match (a.packed, b.packed) {
1222          (a, None) => a,
1223          (None, b) => b,
1224          _ => bail!("conflicting representation hints"),
1225        },
1226        align: match (a.align, b.align) {
1227          (Some(a), Some(b)) => Some(cmp::max(a, b)),
1228          (a, None) => a,
1229          (None, b) => b,
1230        },
1231      })
1232    })
1233}
1234
// Every primitive integer type accepted inside `#[repr(...)]`. Each entry
// maps an `IntegerRepr` variant name to the primitive type it stands for;
// `mk_repr!` (defined below) generates the enum and its conversions.
mk_repr! {
  U8 => u8,
  I8 => i8,
  U16 => u16,
  I16 => i16,
  U32 => u32,
  I32 => i32,
  U64 => u64,
  I64 => i64,
  I128 => i128,
  U128 => u128,
  Usize => usize,
  Isize => isize,
}
// Definition of the `mk_repr!` macro invoked above. The invocation can come
// before this definition because the `use mk_repr;` at the end makes the
// macro nameable by path within the module, not only textually after it.
macro_rules! mk_repr {(
  $(
    $Xn:ident => $xn:ident
  ),* $(,)?
) => (
  // One variant per primitive integer type listed in the invocation.
  #[derive(Debug, Clone, Copy, PartialEq, Eq)]
  enum IntegerRepr {
    $($Xn),*
  }

  // Maps a primitive type name (e.g. "u8") to its variant. Any other string
  // is handed back unchanged as the error value.
  impl<'a> TryFrom<&'a str> for IntegerRepr {
    type Error = &'a str;

    fn try_from(value: &'a str) -> std::result::Result<Self, &'a str> {
      match value {
        $(
          stringify!($xn) => Ok(Self::$Xn),
        )*
        _ => Err(value),
      }
    }
  }

  // Emits the primitive type name (e.g. `u8`) as a single token, usable
  // both inside `#[repr(...)]` and as a type in generated code.
  impl ToTokens for IntegerRepr {
    fn to_tokens(&self, tokens: &mut TokenStream) {
      match self {
        $(
          Self::$Xn => tokens.extend(quote!($xn)),
        )*
      }
    }
  }
)}
use mk_repr;
1284
/// The layout-selecting part of a `#[repr(...)]` attribute.
///
/// `packed` and `align` are tracked separately in [`Representation`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Repr {
  /// No layout hint given: the default Rust representation.
  Rust,
  /// `repr(C)`.
  C,
  /// `repr(transparent)`.
  Transparent,
  /// A bare primitive integer, e.g. `repr(u8)`.
  Integer(IntegerRepr),
  /// `C` together with a primitive integer, e.g. `repr(C, u8)`.
  CWithDiscriminant(IntegerRepr),
}
1293
1294impl Repr {
1295  fn as_integer(&self) -> Option<IntegerRepr> {
1296    if let Self::Integer(v) = self {
1297      Some(*v)
1298    } else {
1299      None
1300    }
1301  }
1302}
1303
/// All hints from an item's `#[repr(...)]` attribute(s), merged together.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct Representation {
  /// `packed(N)`; a bare `packed` is recorded as `Some(1)`.
  packed: Option<u32>,
  /// `align(N)`; when given more than once, the largest value is kept.
  align: Option<u32>,
  /// The layout kind (`C`, `transparent`, a primitive integer, ...).
  repr: Repr,
}
1310
1311impl Default for Representation {
1312  fn default() -> Self {
1313    Self { packed: None, align: None, repr: Repr::Rust }
1314  }
1315}
1316
impl Parse for Representation {
  /// Parses the comma-separated hints inside one `#[repr(...)]` attribute,
  /// i.e. the part between the parentheses.
  ///
  /// `packed`/`packed(N)` and `align(N)` update their own fields. Everything
  /// else selects the layout kind: `C` and one primitive integer merge into
  /// `CWithDiscriminant`, and any other combination of two layout kinds is an
  /// error. Unknown hint names are also errors.
  fn parse(input: ParseStream<'_>) -> Result<Representation> {
    let mut ret = Representation::default();
    while !input.is_empty() {
      let keyword = input.parse::<Ident>()?;
      // Pre-emptively call `.to_string()` *once* (rather than once per
      // `is_ident()` comparison).
      let keyword_str = keyword.to_string();
      let new_repr = match keyword_str.as_str() {
        "C" => Repr::C,
        "transparent" => Repr::Transparent,
        "packed" => {
          // A bare `packed` means `packed(1)`.
          // NOTE(review): a second `packed` in the same attribute silently
          // overwrites the first.
          ret.packed = Some(if input.peek(token::Paren) {
            let contents;
            parenthesized!(contents in input);
            LitInt::base10_parse::<u32>(&contents.parse()?)?
          } else {
            1
          });
          // Consume the optional separating comma; `packed` is not a layout
          // kind, so skip the `ret.repr` merge below.
          let _: Option<Token![,]> = input.parse()?;
          continue;
        }
        "align" => {
          let contents;
          parenthesized!(contents in input);
          let new_align = LitInt::base10_parse::<u32>(&contents.parse()?)?;
          // Repeated `align` hints keep the strictest (largest) alignment.
          ret.align = Some(
            ret
              .align
              .map_or(new_align, |old_align| cmp::max(old_align, new_align)),
          );
          let _: Option<Token![,]> = input.parse()?;
          continue;
        }
        ident => {
          let primitive = IntegerRepr::try_from(ident)
            .map_err(|_| input.error("unrecognized representation hint"))?;
          Repr::Integer(primitive)
        }
      };
      ret.repr = match (ret.repr, new_repr) {
        (Repr::Rust, new_repr) => {
          // This is the first explicit repr.
          new_repr
        }
        (Repr::C, Repr::Integer(integer))
        | (Repr::Integer(integer), Repr::C) => {
          // Both the C repr and an integer repr have been specified
          // -> merge into a C with discriminant.
          Repr::CWithDiscriminant(integer)
        }
        (_, _) => {
          return Err(input.error("duplicate representation hint"));
        }
      };
      // Hints are comma-separated; a trailing comma is accepted.
      let _: Option<Token![,]> = input.parse()?;
    }
    Ok(ret)
  }
}
1376
1377impl ToTokens for Representation {
1378  fn to_tokens(&self, tokens: &mut TokenStream) {
1379    let mut meta = Punctuated::<_, Token![,]>::new();
1380
1381    match self.repr {
1382      Repr::Rust => {}
1383      Repr::C => meta.push(quote!(C)),
1384      Repr::Transparent => meta.push(quote!(transparent)),
1385      Repr::Integer(primitive) => meta.push(quote!(#primitive)),
1386      Repr::CWithDiscriminant(primitive) => {
1387        meta.push(quote!(C));
1388        meta.push(quote!(#primitive));
1389      }
1390    }
1391
1392    if let Some(packed) = self.packed.as_ref() {
1393      let lit = LitInt::new(&packed.to_string(), Span::call_site());
1394      meta.push(quote!(packed(#lit)));
1395    }
1396
1397    if let Some(align) = self.align.as_ref() {
1398      let lit = LitInt::new(&align.to_string(), Span::call_site());
1399      meta.push(quote!(align(#lit)));
1400    }
1401
1402    tokens.extend(quote!(
1403      #[repr(#meta)]
1404    ));
1405  }
1406}
1407
1408fn enum_has_fields<'a>(
1409  mut variants: impl Iterator<Item = &'a Variant>,
1410) -> bool {
1411  variants.any(|v| matches!(v.fields, Fields::Named(_) | Fields::Unnamed(_)))
1412}
1413
1414struct VariantDiscriminantIterator<'a, I: Iterator<Item = &'a Variant> + 'a> {
1415  inner: I,
1416  last_value: i128,
1417}
1418
1419impl<'a, I: Iterator<Item = &'a Variant> + 'a>
1420  VariantDiscriminantIterator<'a, I>
1421{
1422  fn new(inner: I) -> Self {
1423    VariantDiscriminantIterator { inner, last_value: -1 }
1424  }
1425}
1426
1427impl<'a, I: Iterator<Item = &'a Variant> + 'a> Iterator
1428  for VariantDiscriminantIterator<'a, I>
1429{
1430  type Item = Result<(i128, &'a Variant)>;
1431
1432  fn next(&mut self) -> Option<Self::Item> {
1433    let variant = self.inner.next()?;
1434
1435    if let Some((_, discriminant)) = &variant.discriminant {
1436      let discriminant_value = match parse_int_expr(discriminant) {
1437        Ok(value) => value,
1438        Err(e) => return Some(Err(e)),
1439      };
1440      self.last_value = discriminant_value;
1441    } else {
1442      // If this wraps, then either:
1443      // 1. the enum is using repr(u128), so wrapping is correct
1444      // 2. the enum is using repr(i<=128 or u<128), so the compiler will
1445      //    already emit a "wrapping discriminant" E0370 error.
1446      self.last_value = self.last_value.wrapping_add(1);
1447      // Static assert that there is no integer repr > 128 bits. If that
1448      // changes, the above comment is inaccurate and needs to be updated!
1449      // FIXME(zachs18): maybe should also do something to ensure `isize::BITS
1450      // <= 128`?
1451      if let Some(repr) = None::<IntegerRepr> {
1452        match repr {
1453          IntegerRepr::U8
1454          | IntegerRepr::I8
1455          | IntegerRepr::U16
1456          | IntegerRepr::I16
1457          | IntegerRepr::U32
1458          | IntegerRepr::I32
1459          | IntegerRepr::U64
1460          | IntegerRepr::I64
1461          | IntegerRepr::I128
1462          | IntegerRepr::U128
1463          | IntegerRepr::Usize
1464          | IntegerRepr::Isize => (),
1465        }
1466      }
1467    }
1468
1469    Some(Ok((self.last_value, variant)))
1470  }
1471}
1472
1473fn parse_int_expr(expr: &Expr) -> Result<i128> {
1474  match expr {
1475    Expr::Unary(ExprUnary { op: UnOp::Neg(_), expr, .. }) => {
1476      parse_int_expr(expr).map(|int| -int)
1477    }
1478    Expr::Lit(ExprLit { lit: Lit::Int(int), .. }) => int.base10_parse(),
1479    Expr::Lit(ExprLit { lit: Lit::Byte(byte), .. }) => Ok(byte.value().into()),
1480    _ => bail!("Not an integer expression"),
1481  }
1482}
1483
#[cfg(test)]
mod tests {
  use syn::parse_quote;

  use super::{get_repr, IntegerRepr, Repr, Representation};

  // One hint per attribute.
  #[test]
  fn parse_basic_repr() {
    let attr = parse_quote!(#[repr(C)]);
    let repr = get_repr(&[attr]).unwrap();
    assert_eq!(repr, Representation { repr: Repr::C, ..Default::default() });

    let attr = parse_quote!(#[repr(transparent)]);
    let repr = get_repr(&[attr]).unwrap();
    assert_eq!(
      repr,
      Representation { repr: Repr::Transparent, ..Default::default() }
    );

    let attr = parse_quote!(#[repr(u8)]);
    let repr = get_repr(&[attr]).unwrap();
    assert_eq!(
      repr,
      Representation {
        repr: Repr::Integer(IntegerRepr::U8),
        ..Default::default()
      }
    );

    let attr = parse_quote!(#[repr(packed)]);
    let repr = get_repr(&[attr]).unwrap();
    assert_eq!(repr, Representation { packed: Some(1), ..Default::default() });

    let attr = parse_quote!(#[repr(packed(1))]);
    let repr = get_repr(&[attr]).unwrap();
    assert_eq!(repr, Representation { packed: Some(1), ..Default::default() });

    let attr = parse_quote!(#[repr(packed(2))]);
    let repr = get_repr(&[attr]).unwrap();
    assert_eq!(repr, Representation { packed: Some(2), ..Default::default() });

    let attr = parse_quote!(#[repr(align(2))]);
    let repr = get_repr(&[attr]).unwrap();
    assert_eq!(repr, Representation { align: Some(2), ..Default::default() });
  }

  // Repeated `align` hints and `C` + integer merging.
  #[test]
  fn parse_advanced_repr() {
    let attr = parse_quote!(#[repr(align(4), align(2))]);
    let repr = get_repr(&[attr]).unwrap();
    assert_eq!(repr, Representation { align: Some(4), ..Default::default() });

    let attr1 = parse_quote!(#[repr(align(1))]);
    let attr2 = parse_quote!(#[repr(align(4))]);
    let attr3 = parse_quote!(#[repr(align(2))]);
    let repr = get_repr(&[attr1, attr2, attr3]).unwrap();
    assert_eq!(repr, Representation { align: Some(4), ..Default::default() });

    let attr = parse_quote!(#[repr(C, u8)]);
    let repr = get_repr(&[attr]).unwrap();
    assert_eq!(
      repr,
      Representation {
        repr: Repr::CWithDiscriminant(IntegerRepr::U8),
        ..Default::default()
      }
    );

    let attr = parse_quote!(#[repr(u8, C)]);
    let repr = get_repr(&[attr]).unwrap();
    assert_eq!(
      repr,
      Representation {
        repr: Repr::CWithDiscriminant(IntegerRepr::U8),
        ..Default::default()
      }
    );
  }

  // A layout kind combined with `packed`/`align` keeps both parts.
  #[test]
  fn parse_layout_with_modifiers() {
    let attr = parse_quote!(#[repr(C, packed(2))]);
    let repr = get_repr(&[attr]).unwrap();
    assert_eq!(
      repr,
      Representation { repr: Repr::C, packed: Some(2), ..Default::default() }
    );

    let attr = parse_quote!(#[repr(u16, align(8))]);
    let repr = get_repr(&[attr]).unwrap();
    assert_eq!(
      repr,
      Representation {
        repr: Repr::Integer(IntegerRepr::U16),
        align: Some(8),
        ..Default::default()
      }
    );
  }

  // Contradictory, duplicate or unknown hints are rejected.
  #[test]
  fn reject_invalid_repr() {
    // Two layout kinds inside a single attribute.
    let attr = parse_quote!(#[repr(u8, u16)]);
    assert!(get_repr(&[attr]).is_err());

    let attr = parse_quote!(#[repr(C, transparent)]);
    assert!(get_repr(&[attr]).is_err());

    // Incompatible layout kinds split across attributes.
    let attr1 = parse_quote!(#[repr(C)]);
    let attr2 = parse_quote!(#[repr(transparent)]);
    assert!(get_repr(&[attr1, attr2]).is_err());

    // `packed` given by two separate attributes.
    let attr1 = parse_quote!(#[repr(packed)]);
    let attr2 = parse_quote!(#[repr(packed(2))]);
    assert!(get_repr(&[attr1, attr2]).is_err());

    // Not a representation hint at all.
    let attr = parse_quote!(#[repr(foo)]);
    assert!(get_repr(&[attr]).is_err());
  }
}
1563
1564pub fn bytemuck_crate_name(input: &DeriveInput) -> TokenStream {
1565  const ATTR_NAME: &'static str = "crate";
1566
1567  let mut crate_name = quote!(::bytemuck);
1568  for attr in &input.attrs {
1569    if !attr.path().is_ident("bytemuck") {
1570      continue;
1571    }
1572
1573    attr.parse_nested_meta(|meta| {
1574      if meta.path.is_ident(ATTR_NAME) {
1575        let expr: syn::Expr = meta.value()?.parse()?;
1576        let mut value = &expr;
1577        while let syn::Expr::Group(e) = value {
1578          value = &e.expr;
1579        }
1580        if let syn::Expr::Lit(syn::ExprLit {
1581          lit: syn::Lit::Str(lit), ..
1582        }) = value
1583        {
1584          let suffix = lit.suffix();
1585          if !suffix.is_empty() {
1586            bail!(format!("Unexpected suffix `{}` on string literal", suffix))
1587          }
1588          let path: syn::Path = match lit.parse() {
1589            Ok(path) => path,
1590            Err(_) => {
1591              bail!(format!("Failed to parse path: {:?}", lit.value()))
1592            }
1593          };
1594          crate_name = path.into_token_stream();
1595        } else {
1596          bail!(
1597            "Expected bytemuck `crate` attribute to be a string: `crate = \"...\"`",
1598          )
1599        }
1600      }
1601      Ok(())
1602    }).unwrap();
1603  }
1604
1605  return crate_name;
1606}
1607
/// Rustdoc text attached (via `#[doc = ...]`) to helper types that the
/// derives generate, such as the `...Bits` types used by `CheckedBitPattern`.
/// The leading space mirrors the text a `///` comment would produce.
const GENERATED_TYPE_DOCUMENTATION: &str =
  " `bytemuck`-generated type for internal purposes only.";