1#![allow(unused_imports)]
2use std::{cmp, convert::TryFrom};
3
4use proc_macro2::{Ident, Span, TokenStream, TokenTree};
5use quote::{quote, ToTokens};
6use syn::{
7 parse::{Parse, ParseStream, Parser},
8 punctuated::Punctuated,
9 spanned::Spanned,
10 Result, *,
11};
12
/// Early-returns an `Err(syn::Error)` from the enclosing function or closure.
///
/// * `bail!(msg)` — the error is spanned at `Span::call_site()`; `msg` is
///   sliced with `[..]`, so it must be something like a `&str` or `String`.
/// * `bail!(msg => tokens)` — the error is spanned over `tokens` via
///   `Error::new_spanned`, so the diagnostic points at that part of the input.
macro_rules! bail {
    ($msg:expr $(,)?) => {
        return Err(Error::new(Span::call_site(), &$msg[..]))
    };

    ( $msg:expr => $span_to_blame:expr $(,)? ) => {
        return Err(Error::new_spanned(&$span_to_blame, $msg))
    };
}
22
/// Per-trait hooks used by the derive machinery for one bytemuck trait.
///
/// Only `ident` must be implemented; every other hook has a default that
/// implementors override when their trait needs it.
pub trait Derivable {
    /// Path of the trait being derived, e.g. `#crate_name::Pod`.
    fn ident(input: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path>;
    /// Another trait implied by this one (e.g. `AnyBitPattern` returns
    /// `Zeroable`). Defaults to `None`.
    fn implies_trait(_crate_name: &TokenStream) -> Option<TokenStream> {
        None
    }
    /// Compile-time assertions to emit for `input`. Defaults to no tokens.
    fn asserts(
        _input: &DeriveInput, _crate_name: &TokenStream,
    ) -> Result<TokenStream> {
        Ok(quote!())
    }
    /// Validates the item kind and its `#[repr]` attributes. Defaults to
    /// accepting everything.
    fn check_attributes(_ty: &Data, _attributes: &[Attribute]) -> Result<()> {
        Ok(())
    }
    /// Extra generated code as `(items emitted next to the impl, items placed
    /// inside the impl body)`, judging by the implementors in this file.
    /// Defaults to two empty token streams.
    fn trait_impl(
        _input: &DeriveInput, _crate_name: &TokenStream,
    ) -> Result<(TokenStream, TokenStream)> {
        Ok((quote!(), quote!()))
    }
    /// Whether the generated impl should carry a where clause. Defaults to
    /// `true` (`TransparentWrapper` opts out). NOTE(review): the consumer of
    /// this flag lives outside this section.
    fn requires_where_clause() -> bool {
        true
    }
    /// Name of a helper attribute through which users may supply explicit
    /// bounds (`Zeroable` uses `"zeroable"`). Defaults to `None`.
    fn explicit_bounds_attribute_name() -> Option<&'static str> {
        None
    }

    /// Fields whose types should be bounded when generating bounds,
    /// presumably for a "perfect derive" mode — confirm against the caller.
    /// Defaults to `None`.
    fn perfect_derive_fields(_input: &DeriveInput) -> Option<Fields> {
        None
    }
}
56
57pub struct Pod;
58
59impl Derivable for Pod {
60 fn ident(_: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path> {
61 Ok(syn::parse_quote!(#crate_name::Pod))
62 }
63
64 fn asserts(
65 input: &DeriveInput, crate_name: &TokenStream,
66 ) -> Result<TokenStream> {
67 let repr = get_repr(&input.attrs)?;
68
69 let completly_packed =
70 repr.packed == Some(1) || repr.repr == Repr::Transparent;
71
72 if !completly_packed && !input.generics.params.is_empty() {
73 bail!("\
74 Pod requires cannot be derived for non-packed types containing \
75 generic parameters because the padding requirements can't be verified \
76 for generic non-packed structs\
77 " => input.generics.params.first().unwrap());
78 }
79
80 match &input.data {
81 Data::Struct(_) => {
82 let assert_no_padding = if !completly_packed {
83 Some(generate_assert_no_padding(input)?)
84 } else {
85 None
86 };
87 let assert_fields_are_pod = generate_fields_are_trait(
88 input,
89 None,
90 Self::ident(input, crate_name)?,
91 )?;
92
93 Ok(quote!(
94 #assert_no_padding
95 #assert_fields_are_pod
96 ))
97 }
98 Data::Enum(_) => bail!("Deriving Pod is not supported for enums"),
99 Data::Union(_) => bail!("Deriving Pod is not supported for unions"),
100 }
101 }
102
103 fn check_attributes(_ty: &Data, attributes: &[Attribute]) -> Result<()> {
104 let repr = get_repr(attributes)?;
105 match repr.repr {
106 Repr::C => Ok(()),
107 Repr::Transparent => Ok(()),
108 _ => {
109 bail!("Pod requires the type to be #[repr(C)] or #[repr(transparent)]")
110 }
111 }
112 }
113}
114
115pub struct AnyBitPattern;
116
117impl Derivable for AnyBitPattern {
118 fn ident(_: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path> {
119 Ok(syn::parse_quote!(#crate_name::AnyBitPattern))
120 }
121
122 fn implies_trait(crate_name: &TokenStream) -> Option<TokenStream> {
123 Some(quote!(#crate_name::Zeroable))
124 }
125
126 fn asserts(
127 input: &DeriveInput, crate_name: &TokenStream,
128 ) -> Result<TokenStream> {
129 match &input.data {
130 Data::Union(_) => Ok(quote!()), Data::Struct(_) => {
132 generate_fields_are_trait(input, None, Self::ident(input, crate_name)?)
133 }
134 Data::Enum(_) => {
135 bail!("Deriving AnyBitPattern is not supported for enums")
136 }
137 }
138 }
139}
140
/// Marker type selecting the `Zeroable` derive.
pub struct Zeroable;
142
143fn get_zero_variant(enum_: &DataEnum) -> Result<Option<&Variant>> {
146 let iter = VariantDiscriminantIterator::new(enum_.variants.iter());
147 let mut zero_variant = None;
148 for res in iter {
149 let (discriminant, variant) = res?;
150 if discriminant == 0 {
151 zero_variant = Some(variant);
152 break;
153 }
154 }
155 Ok(zero_variant)
156}
157
158impl Derivable for Zeroable {
159 fn ident(_: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path> {
160 Ok(syn::parse_quote!(#crate_name::Zeroable))
161 }
162
163 fn check_attributes(ty: &Data, attributes: &[Attribute]) -> Result<()> {
164 let repr = get_repr(attributes)?;
165 match ty {
166 Data::Struct(_) => Ok(()),
167 Data::Enum(_) => {
168 if !matches!(
169 repr.repr,
170 Repr::C | Repr::Integer(_) | Repr::CWithDiscriminant(_)
171 ) {
172 bail!("Zeroable requires the enum to be an explicit #[repr(Int)] and/or #[repr(C)]")
173 }
174
175 Ok(())
179 }
180 Data::Union(_) => Ok(()),
181 }
182 }
183
184 fn asserts(
185 input: &DeriveInput, crate_name: &TokenStream,
186 ) -> Result<TokenStream> {
187 match &input.data {
188 Data::Union(_) => Ok(quote!()), Data::Struct(_) => {
190 generate_fields_are_trait(input, None, Self::ident(input, crate_name)?)
191 }
192 Data::Enum(enum_) => {
193 let zero_variant = get_zero_variant(enum_)?;
194
195 if zero_variant.is_none() {
196 bail!("No variant's discriminant is 0")
197 };
198
199 generate_fields_are_trait(
200 input,
201 zero_variant,
202 Self::ident(input, crate_name)?,
203 )
204 }
205 }
206 }
207
208 fn explicit_bounds_attribute_name() -> Option<&'static str> {
209 Some("zeroable")
210 }
211
212 fn perfect_derive_fields(input: &DeriveInput) -> Option<Fields> {
213 match &input.data {
214 Data::Struct(struct_) => Some(struct_.fields.clone()),
215 Data::Enum(enum_) => {
216 Some(get_zero_variant(enum_).ok()??.fields.clone())
220 }
221 Data::Union(_) => None,
222 }
223 }
224}
225
226pub struct NoUninit;
227
228impl Derivable for NoUninit {
229 fn ident(_: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path> {
230 Ok(syn::parse_quote!(#crate_name::NoUninit))
231 }
232
233 fn check_attributes(ty: &Data, attributes: &[Attribute]) -> Result<()> {
234 let repr = get_repr(attributes)?;
235 match ty {
236 Data::Struct(_) => match repr.repr {
237 Repr::C | Repr::Transparent => Ok(()),
238 _ => bail!("NoUninit requires the struct to be #[repr(C)] or #[repr(transparent)]"),
239 },
240 Data::Enum(_) => if repr.repr.is_integer() {
241 Ok(())
242 } else {
243 bail!("NoUninit requires the enum to be an explicit #[repr(Int)]")
244 },
245 Data::Union(_) => bail!("NoUninit can only be derived on enums and structs")
246 }
247 }
248
249 fn asserts(
250 input: &DeriveInput, crate_name: &TokenStream,
251 ) -> Result<TokenStream> {
252 if !input.generics.params.is_empty() {
253 bail!("NoUninit cannot be derived for structs containing generic parameters because the padding requirements can't be verified for generic structs");
254 }
255
256 match &input.data {
257 Data::Struct(DataStruct { .. }) => {
258 let assert_no_padding = generate_assert_no_padding(&input)?;
259 let assert_fields_are_no_padding = generate_fields_are_trait(
260 &input,
261 None,
262 Self::ident(input, crate_name)?,
263 )?;
264
265 Ok(quote!(
266 #assert_no_padding
267 #assert_fields_are_no_padding
268 ))
269 }
270 Data::Enum(DataEnum { variants, .. }) => {
271 if variants.iter().any(|variant| !variant.fields.is_empty()) {
272 bail!("Only fieldless enums are supported for NoUninit")
273 } else {
274 Ok(quote!())
275 }
276 }
277 Data::Union(_) => bail!("NoUninit cannot be derived for unions"), }
279 }
280
281 fn trait_impl(
282 _input: &DeriveInput, _crate_name: &TokenStream,
283 ) -> Result<(TokenStream, TokenStream)> {
284 Ok((quote!(), quote!()))
285 }
286}
287
288pub struct CheckedBitPattern;
289
290impl Derivable for CheckedBitPattern {
291 fn ident(_: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path> {
292 Ok(syn::parse_quote!(#crate_name::CheckedBitPattern))
293 }
294
295 fn check_attributes(ty: &Data, attributes: &[Attribute]) -> Result<()> {
296 let repr = get_repr(attributes)?;
297 match ty {
298 Data::Struct(_) => match repr.repr {
299 Repr::C | Repr::Transparent => Ok(()),
300 _ => bail!("CheckedBitPattern derive requires the struct to be #[repr(C)] or #[repr(transparent)]"),
301 },
302 Data::Enum(DataEnum { variants,.. }) => {
303 if !enum_has_fields(variants.iter()){
304 if repr.repr.is_integer() {
305 Ok(())
306 } else {
307 bail!("CheckedBitPattern requires the enum to be an explicit #[repr(Int)]")
308 }
309 } else if matches!(repr.repr, Repr::Rust) {
310 bail!("CheckedBitPattern requires an explicit repr annotation because `repr(Rust)` doesn't have a specified type layout")
311 } else {
312 Ok(())
313 }
314 }
315 Data::Union(_) => bail!("CheckedBitPattern can only be derived on enums and structs")
316 }
317 }
318
319 fn asserts(
320 input: &DeriveInput, crate_name: &TokenStream,
321 ) -> Result<TokenStream> {
322 if !input.generics.params.is_empty() {
323 bail!("CheckedBitPattern cannot be derived for structs containing generic parameters");
324 }
325
326 match &input.data {
327 Data::Struct(DataStruct { .. }) => {
328 let assert_fields_are_maybe_pod = generate_fields_are_trait(
329 &input,
330 None,
331 Self::ident(input, crate_name)?,
332 )?;
333
334 Ok(assert_fields_are_maybe_pod)
335 }
336 Data::Enum(_) => Ok(quote!()),
338 Data::Union(_) => bail!("Internal error in CheckedBitPattern derive"), }
340 }
341
342 fn trait_impl(
343 input: &DeriveInput, crate_name: &TokenStream,
344 ) -> Result<(TokenStream, TokenStream)> {
345 match &input.data {
346 Data::Struct(DataStruct { fields, .. }) => {
347 generate_checked_bit_pattern_struct(
348 &input.ident,
349 fields,
350 &input.attrs,
351 crate_name,
352 )
353 }
354 Data::Enum(DataEnum { variants, .. }) => {
355 generate_checked_bit_pattern_enum(input, variants, crate_name)
356 }
357 Data::Union(_) => bail!("Internal error in CheckedBitPattern derive"), }
359 }
360}
361
362pub struct TransparentWrapper;
363
364impl TransparentWrapper {
365 fn get_wrapper_type(
366 attributes: &[Attribute], fields: &Fields,
367 ) -> Option<TokenStream> {
368 let transparent_param = get_simple_attr(attributes, "transparent");
369 transparent_param.map(|ident| ident.to_token_stream()).or_else(|| {
370 let mut types = get_field_types(&fields);
371 let first_type = types.next();
372 if let Some(_) = types.next() {
373 return None;
375 } else {
376 first_type.map(|ty| ty.to_token_stream())
377 }
378 })
379 }
380}
381
382impl Derivable for TransparentWrapper {
383 fn ident(input: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path> {
384 let fields = get_struct_fields(input)?;
385
386 let ty = match Self::get_wrapper_type(&input.attrs, &fields) {
387 Some(ty) => ty,
388 None => bail!(
389 "\
390 when deriving TransparentWrapper for a struct with more than one field \
391 you need to specify the transparent field using #[transparent(T)]\
392 "
393 ),
394 };
395
396 Ok(syn::parse_quote!(#crate_name::TransparentWrapper<#ty>))
397 }
398
399 fn asserts(
400 input: &DeriveInput, crate_name: &TokenStream,
401 ) -> Result<TokenStream> {
402 let (impl_generics, _ty_generics, where_clause) =
403 input.generics.split_for_impl();
404 let fields = get_struct_fields(input)?;
405 let wrapped_type = match Self::get_wrapper_type(&input.attrs, &fields) {
406 Some(wrapped_type) => wrapped_type.to_string(),
407 None => unreachable!(), };
409 let mut wrapped_field_ty = None;
410 let mut nonwrapped_field_tys = vec![];
411 for field in fields.iter() {
412 let field_ty = &field.ty;
413 if field_ty.to_token_stream().to_string() == wrapped_type {
414 if wrapped_field_ty.is_some() {
415 bail!(
416 "TransparentWrapper can only have one field of the wrapped type"
417 );
418 }
419 wrapped_field_ty = Some(field_ty);
420 } else {
421 nonwrapped_field_tys.push(field_ty);
422 }
423 }
424 if let Some(wrapped_field_ty) = wrapped_field_ty {
425 Ok(quote!(
426 const _: () = {
427 #[repr(transparent)]
428 #[allow(clippy::multiple_bound_locations)]
429 struct AssertWrappedIsWrapped #impl_generics((u8, ::core::marker::PhantomData<#wrapped_field_ty>), #(#nonwrapped_field_tys),*) #where_clause;
430 fn assert_zeroable<Z: #crate_name::Zeroable>() {}
431 #[allow(clippy::multiple_bound_locations)]
432 fn check #impl_generics () #where_clause {
433 #(
434 assert_zeroable::<#nonwrapped_field_tys>();
435 )*
436 }
437 };
438 ))
439 } else {
440 bail!("TransparentWrapper must have one field of the wrapped type")
441 }
442 }
443
444 fn check_attributes(_ty: &Data, attributes: &[Attribute]) -> Result<()> {
445 let repr = get_repr(attributes)?;
446
447 match repr.repr {
448 Repr::Transparent => Ok(()),
449 _ => {
450 bail!(
451 "TransparentWrapper requires the struct to be #[repr(transparent)]"
452 )
453 }
454 }
455 }
456
457 fn requires_where_clause() -> bool {
458 false
459 }
460}
461
462pub struct Contiguous;
463
464impl Derivable for Contiguous {
465 fn ident(_: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path> {
466 Ok(syn::parse_quote!(#crate_name::Contiguous))
467 }
468
469 fn trait_impl(
470 input: &DeriveInput, _crate_name: &TokenStream,
471 ) -> Result<(TokenStream, TokenStream)> {
472 let repr = get_repr(&input.attrs)?;
473
474 let integer_ty = if let Some(integer_ty) = repr.repr.as_integer() {
475 integer_ty
476 } else {
477 bail!("Contiguous requires the enum to be #[repr(Int)]");
478 };
479
480 let variants = get_enum_variants(input)?;
481 if enum_has_fields(variants.clone()) {
482 return Err(Error::new_spanned(
483 &input,
484 "Only fieldless enums are supported",
485 ));
486 }
487
488 let mut variants_with_discriminant =
489 VariantDiscriminantIterator::new(variants);
490
491 let (min, max, count) = variants_with_discriminant.try_fold(
492 (i128::MAX, i128::MIN, 0),
493 |(min, max, count), res| {
494 let (discriminant, _variant) = res?;
495 Ok::<_, Error>((
496 i128::min(min, discriminant),
497 i128::max(max, discriminant),
498 count + 1,
499 ))
500 },
501 )?;
502
503 if max - min != count - 1 {
504 bail! {
505 "Contiguous requires the enum discriminants to be contiguous",
506 }
507 }
508
509 let min_lit = LitInt::new(&format!("{}", min), input.span());
510 let max_lit = LitInt::new(&format!("{}", max), input.span());
511
512 Ok((
517 quote!(),
518 quote! {
519 type Int = #integer_ty;
520
521 #[allow(clippy::missing_docs_in_private_items)]
522 const MIN_VALUE: #integer_ty = #min_lit;
523
524 #[allow(clippy::missing_docs_in_private_items)]
525 const MAX_VALUE: #integer_ty = #max_lit;
526
527 #[inline]
528 fn from_integer(value: Self::Int) -> Option<Self> {
529 #[allow(clippy::manual_range_contains)]
530 if Self::MIN_VALUE <= value && value <= Self::MAX_VALUE {
531 Some(unsafe { ::core::mem::transmute(value) })
532 } else {
533 None
534 }
535 }
536
537 #[inline]
538 fn into_integer(self) -> Self::Int {
539 self as #integer_ty
540 }
541 },
542 ))
543 }
544}
545
546fn get_struct_fields(input: &DeriveInput) -> Result<&Fields> {
547 if let Data::Struct(DataStruct { fields, .. }) = &input.data {
548 Ok(fields)
549 } else {
550 bail!("deriving this trait is only supported for structs")
551 }
552}
553
554fn get_fields(
560 input: &DeriveInput, enum_variant: Option<&Variant>,
561) -> Result<Fields> {
562 match &input.data {
563 Data::Struct(DataStruct { fields, .. }) => Ok(fields.clone()),
564 Data::Union(DataUnion { fields, .. }) => Ok(Fields::Named(fields.clone())),
565 Data::Enum(_) => match enum_variant {
566 Some(variant) => Ok(variant.fields.clone()),
567 None => bail!("deriving this trait is not supported for enums"),
568 },
569 }
570}
571
572fn get_enum_variants<'a>(
573 input: &'a DeriveInput,
574) -> Result<impl Iterator<Item = &'a Variant> + Clone + 'a> {
575 if let Data::Enum(DataEnum { variants, .. }) = &input.data {
576 Ok(variants.iter())
577 } else {
578 bail!("deriving this trait is only supported for enums")
579 }
580}
581
582fn get_field_types<'a>(
583 fields: &'a Fields,
584) -> impl Iterator<Item = &'a Type> + 'a {
585 fields.iter().map(|field| &field.ty)
586}
587
588fn generate_checked_bit_pattern_struct(
589 input_ident: &Ident, fields: &Fields, attrs: &[Attribute],
590 crate_name: &TokenStream,
591) -> Result<(TokenStream, TokenStream)> {
592 let bits_ty = Ident::new(&format!("{}Bits", input_ident), input_ident.span());
593
594 let repr = get_repr(attrs)?;
595
596 let field_names = fields
597 .iter()
598 .enumerate()
599 .map(|(i, field)| {
600 field.ident.clone().unwrap_or_else(|| {
601 Ident::new(&format!("field{}", i), input_ident.span())
602 })
603 })
604 .collect::<Vec<_>>();
605 let field_tys = fields.iter().map(|field| &field.ty).collect::<Vec<_>>();
606
607 let field_name = &field_names[..];
608 let field_ty = &field_tys[..];
609
610 Ok((
611 quote! {
612 #[doc = #GENERATED_TYPE_DOCUMENTATION]
613 #repr
614 #[derive(Clone, Copy, #crate_name::AnyBitPattern)]
615 #[allow(missing_docs)]
616 pub struct #bits_ty {
617 #(#field_name: <#field_ty as #crate_name::CheckedBitPattern>::Bits,)*
618 }
619
620 #[allow(unexpected_cfgs)]
621 const _: () = {
622 #[cfg(not(target_arch = "spirv"))]
623 impl ::core::fmt::Debug for #bits_ty {
624 fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
625 let mut debug_struct = ::core::fmt::Formatter::debug_struct(f, ::core::stringify!(#bits_ty));
626 #(::core::fmt::DebugStruct::field(&mut debug_struct, ::core::stringify!(#field_name), &{ self.#field_name });)*
627 ::core::fmt::DebugStruct::finish(&mut debug_struct)
628 }
629 }
630 };
631 },
632 quote! {
633 type Bits = #bits_ty;
634
635 #[inline]
636 #[allow(clippy::double_comparisons, unused)]
637 fn is_valid_bit_pattern(bits: &#bits_ty) -> bool {
638 #(<#field_ty as #crate_name::CheckedBitPattern>::is_valid_bit_pattern(&{ bits.#field_name }) && )* true
639 }
640 },
641 ))
642}
643
644fn generate_checked_bit_pattern_enum(
645 input: &DeriveInput, variants: &Punctuated<Variant, Token![,]>,
646 crate_name: &TokenStream,
647) -> Result<(TokenStream, TokenStream)> {
648 if enum_has_fields(variants.iter()) {
649 generate_checked_bit_pattern_enum_with_fields(input, variants, crate_name)
650 } else {
651 generate_checked_bit_pattern_enum_without_fields(input, variants)
652 }
653}
654
655fn generate_checked_bit_pattern_enum_without_fields(
656 input: &DeriveInput, variants: &Punctuated<Variant, Token![,]>,
657) -> Result<(TokenStream, TokenStream)> {
658 let span = input.span();
659 let mut variants_with_discriminant =
660 VariantDiscriminantIterator::new(variants.iter());
661
662 let (min, max, count) = variants_with_discriminant.try_fold(
663 (i128::MAX, i128::MIN, 0),
664 |(min, max, count), res| {
665 let (discriminant, _variant) = res?;
666 Ok::<_, Error>((
667 i128::min(min, discriminant),
668 i128::max(max, discriminant),
669 count + 1,
670 ))
671 },
672 )?;
673
674 let check = if count == 0 {
675 quote!(false)
676 } else if max - min == count - 1 {
677 let min_lit = LitInt::new(&format!("{}", min), span);
679 let max_lit = LitInt::new(&format!("{}", max), span);
680
681 quote!(*bits >= #min_lit && *bits <= #max_lit)
682 } else {
683 let variant_discriminant_lits =
685 VariantDiscriminantIterator::new(variants.iter())
686 .map(|res| {
687 let (discriminant, _variant) = res?;
688 Ok(LitInt::new(&format!("{}", discriminant), span))
689 })
690 .collect::<Result<Vec<_>>>()?;
691
692 let first = &variant_discriminant_lits[0];
694 let rest = &variant_discriminant_lits[1..];
695
696 quote!(matches!(*bits, #first #(| #rest )*))
697 };
698
699 let repr = get_repr(&input.attrs)?;
700 let integer = repr.repr.as_integer().unwrap(); Ok((
702 quote!(),
703 quote! {
704 type Bits = #integer;
705
706 #[inline]
707 #[allow(clippy::double_comparisons)]
708 fn is_valid_bit_pattern(bits: &Self::Bits) -> bool {
709 #check
710 }
711 },
712 ))
713}
714
/// Generates `CheckedBitPattern` support for an enum in which at least one
/// variant has fields. `generate_checked_bit_pattern_enum` dispatches here in
/// that case.
///
/// Returns `(helper items emitted next to the impl, items for the impl body)`.
/// The generated `Bits` layout follows the enum's `#[repr]`:
///
/// * `repr(C)` / `repr(C, Int)`: a struct `<Enum>Bits { tag, payload }`.
///   `tag` is `::core::ffi::c_int` for plain `C`, or the given integer.
///   `payload` is a `#[repr(C)]` union `<Enum>Variants` with one field per
///   variant.
/// * `repr(transparent)`: exactly one variant is allowed. Its fields become a
///   tuple struct `<Enum>Bits` whose own `CheckedBitPattern` `Bits` type and
///   validity check are reused.
/// * `repr(Int)`: a union `<Enum>Bits` with a `__tag` field plus one field
///   per variant, where each variant struct begins with the integer tag.
///
/// In the C and Int cases, the bits type keeps the enum's `packed`/`align`
/// hints but is laid out as `repr(C)`. Each variant's fields are wrapped in a
/// `#[repr(C)]` tuple struct `<Enum>Variant<Name>` that derives
/// `CheckedBitPattern`, and the union fields hold the `<...>Bits` types that
/// derive generates. Validity means the tag equals some variant's
/// discriminant and that variant's payload is itself valid; unknown tags are
/// invalid.
///
/// `repr(Rust)` hits `unreachable!()`: `CheckedBitPattern::check_attributes`
/// rejects it for enums with fields. NOTE(review): this relies on the caller
/// running `check_attributes` before `trait_impl`.
///
/// # Errors
/// Propagates `get_repr` and discriminant-evaluation errors. Fails for a
/// transparent enum that does not have exactly one variant.
fn generate_checked_bit_pattern_enum_with_fields(
    input: &DeriveInput, variants: &Punctuated<Variant, Token![,]>,
    crate_name: &TokenStream,
) -> Result<(TokenStream, TokenStream)> {
    let representation = get_repr(&input.attrs)?;
    let vis = &input.vis;

    match representation.repr {
        Repr::Rust => unreachable!(),
        repr @ (Repr::C | Repr::CWithDiscriminant(_)) => {
            // Tag type: C's `int` for plain repr(C), else the explicit integer.
            let integer = match repr {
                Repr::C => quote!(::core::ffi::c_int),
                Repr::CWithDiscriminant(integer) => quote!(#integer),
                _ => unreachable!(),
            };
            let input_ident = &input.ident;

            // Keep the enum's packed/align hints, but lay the bits type out as repr(C).
            let bits_repr = Representation { repr: Repr::C, ..representation };

            let bits_ty_ident =
                Ident::new(&format!("{input_ident}Bits"), input.span());

            let variants_union_ident =
                Ident::new(&format!("{}Variants", input.ident), input.span());

            // `<Enum>Variant<Name>` for each variant, in declaration order.
            let variant_struct_idents = variants.iter().map(|v| {
                Ident::new(&format!("{input_ident}Variant{}", v.ident), v.span())
            });

            // One repr(C) tuple struct per variant holding that variant's fields.
            let variant_struct_definitions =
                variant_struct_idents.clone().zip(variants.iter()).map(|(variant_struct_ident, v)| {
                    let fields = v.fields.iter().map(|v| &v.ty);

                    quote! {
                        #[derive(::core::clone::Clone, ::core::marker::Copy, #crate_name::CheckedBitPattern)]
                        #[repr(C)]
                        #vis struct #variant_struct_ident(#(#fields),*);
                    }
                });

            // Union field `<VariantName>: <Enum>Variant<Name>Bits` per variant.
            let union_fields = variant_struct_idents
                .clone()
                .zip(variants.iter())
                .map(|(variant_struct_ident, v)| {
                    let variant_struct_bits_ident =
                        Ident::new(&format!("{variant_struct_ident}Bits"), input.span());
                    let field_ident = &v.ident;
                    quote! {
                        #field_ident: #variant_struct_bits_ident
                    }
                });

            // One match arm per variant: a tag equal to its discriminant
            // selects that variant's payload and validates it. The read is a
            // union field access (hence `unsafe`); the union and its field
            // types derive `AnyBitPattern`.
            let variant_checks = variant_struct_idents
                .clone()
                .zip(VariantDiscriminantIterator::new(variants.iter()))
                .zip(variants.iter())
                .map(|((variant_struct_ident, discriminant), v)| -> Result<_> {
                    let (discriminant, _variant) = discriminant?;
                    let discriminant = LitInt::new(&discriminant.to_string(), v.span());
                    let ident = &v.ident;
                    Ok(quote! {
                        #discriminant => {
                            let payload = unsafe { &bits.payload.#ident };
                            <#variant_struct_ident as #crate_name::CheckedBitPattern>::is_valid_bit_pattern(payload)
                        }
                    })
                })
                .collect::<Result<Vec<_>>>()?;

            Ok((
                quote! {
                    #[doc = #GENERATED_TYPE_DOCUMENTATION]
                    #[derive(::core::clone::Clone, ::core::marker::Copy, #crate_name::AnyBitPattern)]
                    #bits_repr
                    #vis struct #bits_ty_ident {
                        tag: #integer,
                        payload: #variants_union_ident,
                    }

                    #[allow(unexpected_cfgs)]
                    const _: () = {
                        #[cfg(not(target_arch = "spirv"))]
                        impl ::core::fmt::Debug for #bits_ty_ident {
                            fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
                                let mut debug_struct = ::core::fmt::Formatter::debug_struct(f, ::core::stringify!(#bits_ty_ident));
                                ::core::fmt::DebugStruct::field(&mut debug_struct, "tag", &self.tag);
                                ::core::fmt::DebugStruct::field(&mut debug_struct, "payload", &self.payload);
                                ::core::fmt::DebugStruct::finish(&mut debug_struct)
                            }
                        }
                    };

                    #[derive(::core::clone::Clone, ::core::marker::Copy, #crate_name::AnyBitPattern)]
                    #[repr(C)]
                    #[allow(non_snake_case)]
                    #vis union #variants_union_ident {
                        #(#union_fields,)*
                    }

                    #[allow(unexpected_cfgs)]
                    const _: () = {
                        #[cfg(not(target_arch = "spirv"))]
                        impl ::core::fmt::Debug for #variants_union_ident {
                            fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
                                let mut debug_struct = ::core::fmt::Formatter::debug_struct(f, ::core::stringify!(#variants_union_ident));
                                ::core::fmt::DebugStruct::finish_non_exhaustive(&mut debug_struct)
                            }
                        }
                    };

                    #(#variant_struct_definitions)*
                },
                quote! {
                    type Bits = #bits_ty_ident;

                    #[inline]
                    #[allow(clippy::double_comparisons)]
                    fn is_valid_bit_pattern(bits: &Self::Bits) -> bool {
                        match bits.tag {
                            #(#variant_checks)*
                            _ => false,
                        }
                    }
                },
            ))
        }
        Repr::Transparent => {
            // Checked before indexing `variants[0]` below. This also rejects
            // an enum with zero variants.
            if variants.len() != 1 {
                bail!("enums with more than one variant cannot be transparent")
            }

            let variant = &variants[0];

            let bits_ty = Ident::new(&format!("{}Bits", input.ident), input.span());
            let fields = variant.fields.iter().map(|v| &v.ty);

            // Delegate entirely to a CheckedBitPattern tuple struct of the
            // single variant's fields.
            Ok((
                quote! {
                    #[doc = #GENERATED_TYPE_DOCUMENTATION]
                    #[derive(::core::clone::Clone, ::core::marker::Copy, #crate_name::CheckedBitPattern)]
                    #[repr(C)]
                    #vis struct #bits_ty(#(#fields),*);
                },
                quote! {
                    type Bits = <#bits_ty as #crate_name::CheckedBitPattern>::Bits;

                    #[inline]
                    #[allow(clippy::double_comparisons)]
                    fn is_valid_bit_pattern(bits: &Self::Bits) -> bool {
                        <#bits_ty as #crate_name::CheckedBitPattern>::is_valid_bit_pattern(bits)
                    }
                },
            ))
        }
        Repr::Integer(integer) => {
            // Keep the enum's packed/align hints, but lay the bits type out as repr(C).
            let bits_repr = Representation { repr: Repr::C, ..representation };
            let input_ident = &input.ident;

            let bits_ty_ident =
                Ident::new(&format!("{input_ident}Bits"), input.span());

            // `<Enum>Variant<Name>` for each variant, in declaration order.
            let variant_struct_idents = variants.iter().map(|v| {
                Ident::new(&format!("{input_ident}Variant{}", v.ident), v.span())
            });

            // Each variant struct starts with the tag integer, followed by the
            // variant's own fields.
            let variant_struct_definitions =
                variant_struct_idents.clone().zip(variants.iter()).map(|(variant_struct_ident, v)| {
                    let fields = v.fields.iter().map(|v| &v.ty);

                    quote! {
                        #[derive(::core::clone::Clone, ::core::marker::Copy, #crate_name::CheckedBitPattern)]
                        #[repr(C)]
                        #vis struct #variant_struct_ident(#integer, #(#fields),*);
                    }
                });

            // Union field `<VariantName>: <Enum>Variant<Name>Bits` per variant.
            let union_fields = variant_struct_idents
                .clone()
                .zip(variants.iter())
                .map(|(variant_struct_ident, v)| {
                    let variant_struct_bits_ident =
                        Ident::new(&format!("{variant_struct_ident}Bits"), input.span());
                    let field_ident = &v.ident;
                    quote! {
                        #field_ident: #variant_struct_bits_ident
                    }
                });

            // One match arm per variant: a `__tag` equal to its discriminant
            // selects that variant's union field and validates it.
            let variant_checks = variant_struct_idents
                .clone()
                .zip(VariantDiscriminantIterator::new(variants.iter()))
                .zip(variants.iter())
                .map(|((variant_struct_ident, discriminant), v)| -> Result<_> {
                    let (discriminant, _variant) = discriminant?;
                    let discriminant = LitInt::new(&discriminant.to_string(), v.span());
                    let ident = &v.ident;
                    Ok(quote! {
                        #discriminant => {
                            let payload = unsafe { &bits.#ident };
                            <#variant_struct_ident as #crate_name::CheckedBitPattern>::is_valid_bit_pattern(payload)
                        }
                    })
                })
                .collect::<Result<Vec<_>>>()?;

            Ok((
                quote! {
                    #[doc = #GENERATED_TYPE_DOCUMENTATION]
                    #[derive(::core::clone::Clone, ::core::marker::Copy, #crate_name::AnyBitPattern)]
                    #bits_repr
                    #[allow(non_snake_case)]
                    #vis union #bits_ty_ident {
                        __tag: #integer,
                        #(#union_fields,)*
                    }

                    #[allow(unexpected_cfgs)]
                    const _: () = {
                        #[cfg(not(target_arch = "spirv"))]
                        impl ::core::fmt::Debug for #bits_ty_ident {
                            fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
                                let mut debug_struct = ::core::fmt::Formatter::debug_struct(f, ::core::stringify!(#bits_ty_ident));
                                ::core::fmt::DebugStruct::field(&mut debug_struct, "tag", unsafe { &self.__tag });
                                ::core::fmt::DebugStruct::finish_non_exhaustive(&mut debug_struct)
                            }
                        }
                    };

                    #(#variant_struct_definitions)*
                },
                quote! {
                    type Bits = #bits_ty_ident;

                    #[inline]
                    #[allow(clippy::double_comparisons)]
                    fn is_valid_bit_pattern(bits: &Self::Bits) -> bool {
                        match unsafe { bits.__tag } {
                            #(#variant_checks)*
                            _ => false,
                        }
                    }
                },
            ))
        }
    }
}
983
984fn generate_assert_no_padding(input: &DeriveInput) -> Result<TokenStream> {
987 let struct_type = &input.ident;
988 let enum_variant = None; let fields = get_fields(input, enum_variant)?;
990
991 let mut field_types = get_field_types(&fields);
992 let size_sum = if let Some(first) = field_types.next() {
993 let size_first = quote!(::core::mem::size_of::<#first>());
994 let size_rest = quote!(#( + ::core::mem::size_of::<#field_types>() )*);
995
996 quote!(#size_first #size_rest)
997 } else {
998 quote!(0)
999 };
1000
1001 Ok(quote! {const _: fn() = || {
1002 #[doc(hidden)]
1003 struct TypeWithoutPadding([u8; #size_sum]);
1004 let _ = ::core::mem::transmute::<#struct_type, TypeWithoutPadding>;
1005 };})
1006}
1007
1008fn generate_fields_are_trait(
1010 input: &DeriveInput, enum_variant: Option<&Variant>, trait_: syn::Path,
1011) -> Result<TokenStream> {
1012 let (impl_generics, _ty_generics, where_clause) =
1013 input.generics.split_for_impl();
1014 let fields = get_fields(input, enum_variant)?;
1015 let field_types = get_field_types(&fields);
1016 Ok(quote! {#(const _: fn() = || {
1017 #[allow(clippy::missing_const_for_fn)]
1018 #[doc(hidden)]
1019 fn check #impl_generics () #where_clause {
1020 fn assert_impl<T: #trait_>() {}
1021 assert_impl::<#field_types>();
1022 }
1023 };)*
1024 })
1025}
1026
1027fn get_ident_from_stream(tokens: TokenStream) -> Option<Ident> {
1028 match tokens.into_iter().next() {
1029 Some(TokenTree::Group(group)) => get_ident_from_stream(group.stream()),
1030 Some(TokenTree::Ident(ident)) => Some(ident),
1031 _ => None,
1032 }
1033}
1034
1035fn get_simple_attr(attributes: &[Attribute], attr_name: &str) -> Option<Ident> {
1037 for attr in attributes {
1038 if let (AttrStyle::Outer, Meta::List(list)) = (&attr.style, &attr.meta) {
1039 if list.path.is_ident(attr_name) {
1040 if let Some(ident) = get_ident_from_stream(list.tokens.clone()) {
1041 return Some(ident);
1042 }
1043 }
1044 }
1045 }
1046
1047 None
1048}
1049
1050fn get_repr(attributes: &[Attribute]) -> Result<Representation> {
1051 attributes
1052 .iter()
1053 .filter_map(|attr| {
1054 if attr.path().is_ident("repr") {
1055 Some(attr.parse_args::<Representation>())
1056 } else {
1057 None
1058 }
1059 })
1060 .try_fold(Representation::default(), |a, b| {
1061 let b = b?;
1062 Ok(Representation {
1063 repr: match (a.repr, b.repr) {
1064 (a, Repr::Rust) => a,
1065 (Repr::Rust, b) => b,
1066 _ => bail!("conflicting representation hints"),
1067 },
1068 packed: match (a.packed, b.packed) {
1069 (a, None) => a,
1070 (None, b) => b,
1071 _ => bail!("conflicting representation hints"),
1072 },
1073 align: match (a.align, b.align) {
1074 (Some(a), Some(b)) => Some(cmp::max(a, b)),
1075 (a, None) => a,
1076 (None, b) => b,
1077 },
1078 })
1079 })
1080}
1081
// Defines `IntegerRepr`, the primitive integer types accepted in
// `#[repr(...)]`. Each `Variant => type` pair maps an enum variant to the type
// name it is parsed from and printed as. The macro is defined just below; it
// is usable here because of the `use mk_repr;` that follows its definition.
mk_repr! {
    U8 => u8,
    I8 => i8,
    U16 => u16,
    I16 => i16,
    U32 => u32,
    I32 => i32,
    U64 => u64,
    I64 => i64,
    I128 => i128,
    U128 => u128,
    Usize => usize,
    Isize => isize,
}
// Generates `enum IntegerRepr` from `Variant => type` pairs, plus:
// * `TryFrom<&str>`: maps a type name such as "u8" to its variant; an
//   unrecognized name is returned unchanged as the error;
// * `ToTokens`: emits the type name back as a token (e.g. `u8`).
macro_rules! mk_repr {(
    $(
        $Xn:ident => $xn:ident
    ),* $(,)?
) => (
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    enum IntegerRepr {
        $($Xn),*
    }

    impl<'a> TryFrom<&'a str> for IntegerRepr {
        type Error = &'a str;

        fn try_from(value: &'a str) -> std::result::Result<Self, &'a str> {
            match value {
                $(
                    stringify!($xn) => Ok(Self::$Xn),
                )*
                _ => Err(value),
            }
        }
    }

    impl ToTokens for IntegerRepr {
        fn to_tokens(&self, tokens: &mut TokenStream) {
            match self {
                $(
                    Self::$Xn => tokens.extend(quote!($xn)),
                )*
            }
        }
    }
)}
// Gives the macro path-based scope, which is what lets the invocation above
// appear before this definition.
use mk_repr;
1131
/// The layout-kind part of a `#[repr(...)]` attribute.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Repr {
    /// No layout hint (the default).
    Rust,
    /// `#[repr(C)]`.
    C,
    /// `#[repr(transparent)]`.
    Transparent,
    /// A bare primitive integer repr such as `#[repr(u8)]`.
    Integer(IntegerRepr),
    /// `C` combined with an integer repr, e.g. `#[repr(C, u8)]`.
    CWithDiscriminant(IntegerRepr),
}
1140
1141impl Repr {
1142 fn is_integer(&self) -> bool {
1143 matches!(self, Self::Integer(..))
1144 }
1145
1146 fn as_integer(&self) -> Option<IntegerRepr> {
1147 if let Self::Integer(v) = self {
1148 Some(*v)
1149 } else {
1150 None
1151 }
1152 }
1153}
1154
/// All hints collected from `#[repr(...)]` attribute(s).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct Representation {
    /// `packed(N)`; a bare `packed` is stored as `Some(1)`.
    packed: Option<u32>,
    /// `align(N)`; when several are given, the largest is kept.
    align: Option<u32>,
    /// The layout kind (`C`, `transparent`, an integer, …).
    repr: Repr,
}
1161
1162impl Default for Representation {
1163 fn default() -> Self {
1164 Self { packed: None, align: None, repr: Repr::Rust }
1165 }
1166}
1167
1168impl Parse for Representation {
1169 fn parse(input: ParseStream<'_>) -> Result<Representation> {
1170 let mut ret = Representation::default();
1171 while !input.is_empty() {
1172 let keyword = input.parse::<Ident>()?;
1173 let keyword_str = keyword.to_string();
1175 let new_repr = match keyword_str.as_str() {
1176 "C" => Repr::C,
1177 "transparent" => Repr::Transparent,
1178 "packed" => {
1179 ret.packed = Some(if input.peek(token::Paren) {
1180 let contents;
1181 parenthesized!(contents in input);
1182 LitInt::base10_parse::<u32>(&contents.parse()?)?
1183 } else {
1184 1
1185 });
1186 let _: Option<Token![,]> = input.parse()?;
1187 continue;
1188 }
1189 "align" => {
1190 let contents;
1191 parenthesized!(contents in input);
1192 let new_align = LitInt::base10_parse::<u32>(&contents.parse()?)?;
1193 ret.align = Some(
1194 ret
1195 .align
1196 .map_or(new_align, |old_align| cmp::max(old_align, new_align)),
1197 );
1198 let _: Option<Token![,]> = input.parse()?;
1199 continue;
1200 }
1201 ident => {
1202 let primitive = IntegerRepr::try_from(ident)
1203 .map_err(|_| input.error("unrecognized representation hint"))?;
1204 Repr::Integer(primitive)
1205 }
1206 };
1207 ret.repr = match (ret.repr, new_repr) {
1208 (Repr::Rust, new_repr) => {
1209 new_repr
1211 }
1212 (Repr::C, Repr::Integer(integer))
1213 | (Repr::Integer(integer), Repr::C) => {
1214 Repr::CWithDiscriminant(integer)
1217 }
1218 (_, _) => {
1219 return Err(input.error("duplicate representation hint"));
1220 }
1221 };
1222 let _: Option<Token![,]> = input.parse()?;
1223 }
1224 Ok(ret)
1225 }
1226}
1227
1228impl ToTokens for Representation {
1229 fn to_tokens(&self, tokens: &mut TokenStream) {
1230 let mut meta = Punctuated::<_, Token![,]>::new();
1231
1232 match self.repr {
1233 Repr::Rust => {}
1234 Repr::C => meta.push(quote!(C)),
1235 Repr::Transparent => meta.push(quote!(transparent)),
1236 Repr::Integer(primitive) => meta.push(quote!(#primitive)),
1237 Repr::CWithDiscriminant(primitive) => {
1238 meta.push(quote!(C));
1239 meta.push(quote!(#primitive));
1240 }
1241 }
1242
1243 if let Some(packed) = self.packed.as_ref() {
1244 let lit = LitInt::new(&packed.to_string(), Span::call_site());
1245 meta.push(quote!(packed(#lit)));
1246 }
1247
1248 if let Some(align) = self.align.as_ref() {
1249 let lit = LitInt::new(&align.to_string(), Span::call_site());
1250 meta.push(quote!(align(#lit)));
1251 }
1252
1253 tokens.extend(quote!(
1254 #[repr(#meta)]
1255 ));
1256 }
1257}
1258
1259fn enum_has_fields<'a>(
1260 mut variants: impl Iterator<Item = &'a Variant>,
1261) -> bool {
1262 variants.any(|v| matches!(v.fields, Fields::Named(_) | Fields::Unnamed(_)))
1263}
1264
1265struct VariantDiscriminantIterator<'a, I: Iterator<Item = &'a Variant> + 'a> {
1266 inner: I,
1267 last_value: i128,
1268}
1269
1270impl<'a, I: Iterator<Item = &'a Variant> + 'a>
1271 VariantDiscriminantIterator<'a, I>
1272{
1273 fn new(inner: I) -> Self {
1274 VariantDiscriminantIterator { inner, last_value: -1 }
1275 }
1276}
1277
1278impl<'a, I: Iterator<Item = &'a Variant> + 'a> Iterator
1279 for VariantDiscriminantIterator<'a, I>
1280{
1281 type Item = Result<(i128, &'a Variant)>;
1282
1283 fn next(&mut self) -> Option<Self::Item> {
1284 let variant = self.inner.next()?;
1285
1286 if let Some((_, discriminant)) = &variant.discriminant {
1287 let discriminant_value = match parse_int_expr(discriminant) {
1288 Ok(value) => value,
1289 Err(e) => return Some(Err(e)),
1290 };
1291 self.last_value = discriminant_value;
1292 } else {
1293 self.last_value = self.last_value.wrapping_add(1);
1298 if let Some(repr) = None::<IntegerRepr> {
1303 match repr {
1304 IntegerRepr::U8
1305 | IntegerRepr::I8
1306 | IntegerRepr::U16
1307 | IntegerRepr::I16
1308 | IntegerRepr::U32
1309 | IntegerRepr::I32
1310 | IntegerRepr::U64
1311 | IntegerRepr::I64
1312 | IntegerRepr::I128
1313 | IntegerRepr::U128
1314 | IntegerRepr::Usize
1315 | IntegerRepr::Isize => (),
1316 }
1317 }
1318 }
1319
1320 Some(Ok((self.last_value, variant)))
1321 }
1322}
1323
1324fn parse_int_expr(expr: &Expr) -> Result<i128> {
1325 match expr {
1326 Expr::Unary(ExprUnary { op: UnOp::Neg(_), expr, .. }) => {
1327 parse_int_expr(expr).map(|int| -int)
1328 }
1329 Expr::Lit(ExprLit { lit: Lit::Int(int), .. }) => int.base10_parse(),
1330 Expr::Lit(ExprLit { lit: Lit::Byte(byte), .. }) => Ok(byte.value().into()),
1331 _ => bail!("Not an integer expression"),
1332 }
1333}
1334
#[cfg(test)]
mod tests {
  use syn::{parse_quote, Attribute};

  use super::{get_repr, IntegerRepr, Repr, Representation};

  /// Runs `get_repr` over `attrs` and asserts that it succeeds with `expected`.
  fn assert_repr(attrs: &[Attribute], expected: Representation) {
    assert_eq!(get_repr(attrs).unwrap(), expected);
  }

  #[test]
  fn parse_basic_repr() {
    // One attribute per case, each carrying a single hint.
    let cases: Vec<(Attribute, Representation)> = vec![
      (
        parse_quote!(#[repr(C)]),
        Representation { repr: Repr::C, ..Default::default() },
      ),
      (
        parse_quote!(#[repr(transparent)]),
        Representation { repr: Repr::Transparent, ..Default::default() },
      ),
      (
        parse_quote!(#[repr(u8)]),
        Representation {
          repr: Repr::Integer(IntegerRepr::U8),
          ..Default::default()
        },
      ),
      (
        parse_quote!(#[repr(packed)]),
        Representation { packed: Some(1), ..Default::default() },
      ),
      (
        parse_quote!(#[repr(packed(1))]),
        Representation { packed: Some(1), ..Default::default() },
      ),
      (
        parse_quote!(#[repr(packed(2))]),
        Representation { packed: Some(2), ..Default::default() },
      ),
      (
        parse_quote!(#[repr(align(2))]),
        Representation { align: Some(2), ..Default::default() },
      ),
    ];

    for (attr, expected) in cases {
      assert_repr(&[attr], expected);
    }
  }

  #[test]
  fn parse_advanced_repr() {
    // Repeated `align` hints keep the largest alignment, whether they share
    // one attribute...
    let single: Attribute = parse_quote!(#[repr(align(4), align(2))]);
    assert_repr(
      &[single],
      Representation { align: Some(4), ..Default::default() },
    );

    // ...or are spread over several attributes.
    let spread: [Attribute; 3] = [
      parse_quote!(#[repr(align(1))]),
      parse_quote!(#[repr(align(4))]),
      parse_quote!(#[repr(align(2))]),
    ];
    assert_repr(
      &spread,
      Representation { align: Some(4), ..Default::default() },
    );

    // `C` plus an integer hint merges into `CWithDiscriminant` in either order.
    let c_with_u8 = Representation {
      repr: Repr::CWithDiscriminant(IntegerRepr::U8),
      ..Default::default()
    };
    let c_first: Attribute = parse_quote!(#[repr(C, u8)]);
    assert_repr(&[c_first], c_with_u8);
    let int_first: Attribute = parse_quote!(#[repr(u8, C)]);
    assert_repr(&[int_first], c_with_u8);
  }
}
1414
1415pub fn bytemuck_crate_name(input: &DeriveInput) -> TokenStream {
1416 const ATTR_NAME: &'static str = "crate";
1417
1418 let mut crate_name = quote!(::bytemuck);
1419 for attr in &input.attrs {
1420 if !attr.path().is_ident("bytemuck") {
1421 continue;
1422 }
1423
1424 attr.parse_nested_meta(|meta| {
1425 if meta.path.is_ident(ATTR_NAME) {
1426 let expr: syn::Expr = meta.value()?.parse()?;
1427 let mut value = &expr;
1428 while let syn::Expr::Group(e) = value {
1429 value = &e.expr;
1430 }
1431 if let syn::Expr::Lit(syn::ExprLit {
1432 lit: syn::Lit::Str(lit), ..
1433 }) = value
1434 {
1435 let suffix = lit.suffix();
1436 if !suffix.is_empty() {
1437 bail!(format!("Unexpected suffix `{}` on string literal", suffix))
1438 }
1439 let path: syn::Path = match lit.parse() {
1440 Ok(path) => path,
1441 Err(_) => {
1442 bail!(format!("Failed to parse path: {:?}", lit.value()))
1443 }
1444 };
1445 crate_name = path.into_token_stream();
1446 } else {
1447 bail!(
1448 "Expected bytemuck `crate` attribute to be a string: `crate = \"...\"`",
1449 )
1450 }
1451 }
1452 Ok(())
1453 }).unwrap();
1454 }
1455
1456 return crate_name;
1457}
1458
/// Documentation text intended for helper types generated by the derives.
///
/// NOTE(review): presumably attached via `#[doc = ...]`. The leading space
/// matches how `/// text` desugars to `#[doc = " text"]`; confirm at the use
/// site.
const GENERATED_TYPE_DOCUMENTATION: &str =
  " `bytemuck`-generated type for internal purposes only.";