neli_proc_macros/
shared.rs

1use std::{any::type_name, collections::HashMap};
2
3use proc_macro::TokenStream;
4use proc_macro2::{Span, TokenStream as TokenStream2};
5use quote::{quote, ToTokens};
6use syn::{
7    parse,
8    parse::Parse,
9    parse_str,
10    punctuated::Punctuated,
11    token::{Add, Colon2},
12    Attribute, Expr, Fields, FieldsNamed, FieldsUnnamed, GenericParam, Generics, Ident, Index,
13    ItemStruct, LifetimeDef, Lit, Meta, MetaNameValue, NestedMeta, Path, PathArguments,
14    PathSegment, Token, TraitBound, TraitBoundModifier, Type, TypeParam, TypeParamBound, Variant,
15};
16
17/// Represents a field as either an identifier or an index.
18pub enum FieldRepr {
19    Index(Index),
20    Ident(Ident),
21}
22
23impl ToTokens for FieldRepr {
24    fn to_tokens(&self, tokens: &mut TokenStream2) {
25        match self {
26            FieldRepr::Index(i) => i.to_tokens(tokens),
27            FieldRepr::Ident(i) => i.to_tokens(tokens),
28        }
29    }
30}
31
32/// Represents the field name, type, and all attributes associated
33/// with this field.
34pub struct FieldInfo {
35    field_name: FieldRepr,
36    field_type: Type,
37    field_attrs: Vec<Attribute>,
38}
39
40impl FieldInfo {
41    /// Convert field info to a tuple.
42    fn into_tuple(self) -> (FieldRepr, Type, Vec<Attribute>) {
43        (self.field_name, self.field_type, self.field_attrs)
44    }
45
46    /// Convert a vector of [`FieldInfo`]s to a tuple of vectors
47    /// each containing name, type, or attributes.
48    pub fn to_vecs<I>(v: I) -> (Vec<FieldRepr>, Vec<Type>, Vec<Vec<Attribute>>)
49    where
50        I: Iterator<Item = Self>,
51    {
52        v.into_iter().fold(
53            (Vec::new(), Vec::new(), Vec::new()),
54            |(mut names, mut types, mut attrs), info| {
55                let (name, ty, attr) = info.into_tuple();
56                names.push(name);
57                types.push(ty);
58                attrs.push(attr);
59                (names, types, attrs)
60            },
61        )
62    }
63}
64
65/// Necessary information for a given struct to generate trait
66/// implementations.
67pub struct StructInfo {
68    struct_name: Ident,
69    generics: Generics,
70    generics_without_bounds: Generics,
71    field_info: Vec<FieldInfo>,
72    padded: bool,
73}
74
75type StructInfoTuple = (
76    Ident,
77    Generics,
78    Generics,
79    Vec<FieldRepr>,
80    Vec<Type>,
81    Vec<Vec<Attribute>>,
82    bool,
83);
84
85impl StructInfo {
86    /// Extract the necessary information from an
87    /// [`ItemStruct`][syn::ItemStruct] data structure.
88    pub fn from_item_struct(
89        i: ItemStruct,
90        trait_name: Option<&str>,
91        trait_bound_path: &str,
92        uses_self: bool,
93    ) -> Self {
94        let (mut generics, generics_without_bounds) = process_impl_generics(i.generics, trait_name);
95        let trait_bounds = process_trait_bounds(&i.attrs, trait_bound_path);
96        override_trait_bounds_on_generics(&mut generics, &trait_bounds);
97        let field_info = match i.fields {
98            Fields::Named(fields_named) => generate_named_fields(fields_named),
99            Fields::Unnamed(fields_unnamed) => generate_unnamed_fields(fields_unnamed, uses_self),
100            Fields::Unit => Vec::new(),
101        };
102        let padded = process_padding(&i.attrs);
103
104        StructInfo {
105            struct_name: i.ident,
106            generics,
107            generics_without_bounds,
108            field_info,
109            padded,
110        }
111    }
112
113    /// Remove the last field from the record.
114    pub fn pop_field(&mut self) {
115        let _ = self.field_info.pop();
116    }
117
118    /// Convert all necessary struct information into a tuple of
119    /// values.
120    pub fn into_tuple(mut self) -> StructInfoTuple {
121        let (field_names, field_types, field_attrs) = self.field_info();
122        (
123            self.struct_name,
124            self.generics,
125            self.generics_without_bounds,
126            field_names,
127            field_types,
128            field_attrs,
129            self.padded,
130        )
131    }
132
133    /// Convert all field information into a tuple.
134    fn field_info(&mut self) -> (Vec<FieldRepr>, Vec<Type>, Vec<Vec<Attribute>>) {
135        FieldInfo::to_vecs(self.field_info.drain(..))
136    }
137}
138
139/// Convert a list of identifiers into a path where the path segments
140/// are added in the order that they appear in the list.
141fn path_from_idents(idents: &[&str]) -> Path {
142    Path {
143        leading_colon: None,
144        segments: idents
145            .iter()
146            .map(|ident| PathSegment {
147                ident: Ident::new(ident, Span::call_site()),
148                arguments: PathArguments::None,
149            })
150            .collect::<Punctuated<PathSegment, Colon2>>(),
151    }
152}
153
154/// Process all type parameters in the type parameter definition for
155/// an `impl` block. Optionally add a trait bound for all type parameters
156/// if `required_trait` is `Some(_)`.
157///
158/// The first return value in the tuple is the list of type parameters
159/// with trait bounds added. The second argument is a list of type
160/// parameters without trait bounds to be passed into the type parameter
161/// list for a struct.
162///
163/// # Example:
164/// ## impl block
165///
166/// ```no_compile
167/// trait MyTrait {}
168///
169/// impl<T, P> MyStruct<T, P> {
170///     fn nothing() {}
171/// }
172/// ```
173///
174/// ## Method call
175/// `neli_proc_macros::process_impl_generics(generics, Some("MyTrait"))`
176///
177/// ## Result
178/// ```no_compile
179/// (<T: MyTrait, P: MyTrait>, <T, P>)
180/// ```
181///
182/// or rather:
183///
184/// ```no_compile
185/// impl<T: MyTrait, P: MyTrait> MyStruct<T, P> {
186///     fn nothing() {}
187/// }
188/// ```
189pub fn process_impl_generics(
190    mut generics: Generics,
191    required_trait: Option<&str>,
192) -> (Generics, Generics) {
193    if let Some(rt) = required_trait {
194        for gen in generics.params.iter_mut() {
195            if let GenericParam::Type(param) = gen {
196                param.colon_token = Some(Token![:](Span::call_site()));
197                param.bounds.push(TypeParamBound::Trait(TraitBound {
198                    paren_token: None,
199                    modifier: TraitBoundModifier::None,
200                    lifetimes: None,
201                    path: path_from_idents(&["neli", rt]),
202                }));
203                param.eq_token = None;
204                param.default = None;
205            }
206        }
207    }
208
209    let mut generics_without_bounds: Generics = generics.clone();
210    for gen in generics_without_bounds.params.iter_mut() {
211        if let GenericParam::Type(param) = gen {
212            param.colon_token = None;
213            param.bounds.clear();
214            param.eq_token = None;
215            param.default = None;
216        }
217    }
218
219    (generics, generics_without_bounds)
220}
221
222/// Remove attributes that should not be carried over to an `impl`
223/// definition and only belong in the data structure like documentation
224/// attributes.
225pub fn remove_bad_attrs(attrs: Vec<Attribute>) -> Vec<Attribute> {
226    attrs
227        .into_iter()
228        .filter(|attr| {
229            if let Ok(meta) = attr.parse_meta() {
230                match meta {
231                    Meta::NameValue(MetaNameValue { path, .. }) => {
232                        !(path == parse_str::<Path>("doc").expect("doc should be valid path"))
233                    }
234                    _ => true,
235                }
236            } else {
237                panic!("Could not parse provided attribute {}", attr.tokens,)
238            }
239        })
240        .collect()
241}
242
243/// Generate a pattern and associated expression for each variant
244/// in an enum.
245fn generate_pat_and_expr<N, U>(
246    enum_name: Ident,
247    var_name: Ident,
248    fields: Fields,
249    generate_named_pat_and_expr: &N,
250    generate_unnamed_pat_and_expr: &U,
251    unit: &TokenStream2,
252) -> TokenStream2
253where
254    N: Fn(Ident, Ident, FieldsNamed) -> TokenStream2,
255    U: Fn(Ident, Ident, FieldsUnnamed) -> TokenStream2,
256{
257    match fields {
258        Fields::Named(fields) => generate_named_pat_and_expr(enum_name, var_name, fields),
259        Fields::Unnamed(fields) => generate_unnamed_pat_and_expr(enum_name, var_name, fields),
260        Fields::Unit => quote! {
261            #enum_name::#var_name => #unit,
262        },
263    }
264}
265
266/// Convert an enum variant into an arm of a match statement.
267fn generate_arm<N, U>(
268    attrs: Vec<Attribute>,
269    enum_name: Ident,
270    var_name: Ident,
271    fields: Fields,
272    generate_named_pat_and_expr: &N,
273    generate_unnamed_pat_and_expr: &U,
274    unit: &TokenStream2,
275) -> TokenStream2
276where
277    N: Fn(Ident, Ident, FieldsNamed) -> TokenStream2,
278    U: Fn(Ident, Ident, FieldsUnnamed) -> TokenStream2,
279{
280    let attrs = remove_bad_attrs(attrs)
281        .into_iter()
282        .map(|attr| {
283            attr.parse_meta()
284                .unwrap_or_else(|_| panic!("Failed to parse attribute {}", attr.tokens))
285        })
286        .collect::<Vec<_>>();
287    let arm = generate_pat_and_expr(
288        enum_name,
289        var_name,
290        fields,
291        generate_named_pat_and_expr,
292        generate_unnamed_pat_and_expr,
293        unit,
294    );
295    quote! {
296        #(
297            #attrs
298        )*
299        #arm
300    }
301}
302
303/// Generate all arms of a match statement.
304pub fn generate_arms<N, U>(
305    enum_name: Ident,
306    variants: Vec<Variant>,
307    generate_named_pat_and_expr: N,
308    generate_unnamed_pat_and_expr: U,
309    unit: TokenStream2,
310) -> Vec<TokenStream2>
311where
312    N: Fn(Ident, Ident, FieldsNamed) -> TokenStream2,
313    U: Fn(Ident, Ident, FieldsUnnamed) -> TokenStream2,
314{
315    variants
316        .into_iter()
317        .map(|var| {
318            let variant_name = var.ident;
319            generate_arm(
320                var.attrs,
321                enum_name.clone(),
322                variant_name,
323                var.fields,
324                &generate_named_pat_and_expr,
325                &generate_unnamed_pat_and_expr,
326                &unit,
327            )
328        })
329        .collect()
330}
331
332/// Generate a list of named fields in accordance with the struct.
333pub fn generate_named_fields(fields: FieldsNamed) -> Vec<FieldInfo> {
334    fields
335        .named
336        .into_iter()
337        .fold(Vec::new(), |mut info, field| {
338            info.push(FieldInfo {
339                field_name: FieldRepr::Ident(field.ident.expect("Must be named")),
340                field_type: field.ty,
341                field_attrs: field.attrs,
342            });
343            info
344        })
345}
346
347/// Generate unnamed fields as either indicies to be accessed using
348/// `self` or placeholder variable names for match-style patterns.
349pub fn generate_unnamed_fields(fields: FieldsUnnamed, uses_self: bool) -> Vec<FieldInfo> {
350    fields
351        .unnamed
352        .into_iter()
353        .enumerate()
354        .fold(Vec::new(), |mut fields, (index, field)| {
355            fields.push(FieldInfo {
356                field_name: if uses_self {
357                    FieldRepr::Index(Index {
358                        index: index as u32,
359                        span: Span::call_site(),
360                    })
361                } else {
362                    FieldRepr::Ident(Ident::new(
363                        &String::from((b'a' + index as u8) as char),
364                        Span::call_site(),
365                    ))
366                },
367                field_type: field.ty,
368                field_attrs: field.attrs,
369            });
370            fields
371        })
372}
373
374/// Returns [`true`] if the given attribute is present in the list.
375fn attr_present(attrs: &[Attribute], attr_name: &str) -> bool {
376    for attr in attrs {
377        let meta = attr
378            .parse_meta()
379            .unwrap_or_else(|_| panic!("Failed to parse attribute {}", attr.tokens));
380        if let Meta::List(list) = meta {
381            if list.path == parse_str::<Path>("neli").expect("neli is valid path") {
382                for nested in list.nested {
383                    if let NestedMeta::Meta(Meta::Path(path)) = nested {
384                        if path
385                            == parse_str::<Path>(attr_name)
386                                .unwrap_or_else(|_| panic!("{} should be valid path", attr_name))
387                        {
388                            return true;
389                        }
390                    }
391                }
392            }
393        }
394    }
395    false
396}
397
398/// Process attributes to find all attributes with the name `attr_name`.
399/// Return a [`Vec`] of [`Option`] types with the associated literal parsed
400/// into type parameter `T`. `T` must allow parsing from a string to be
401/// used with this method.
402fn process_attr<T>(attrs: &[Attribute], attr_name: &str) -> Vec<Option<T>>
403where
404    T: Parse,
405{
406    let mut output = Vec::new();
407    for attr in attrs {
408        let meta = attr
409            .parse_meta()
410            .unwrap_or_else(|_| panic!("Failed to parse attribute {}", attr.tokens));
411        if let Meta::List(list) = meta {
412            if list.path == parse_str::<Path>("neli").expect("neli is valid path") {
413                for nested in list.nested {
414                    if let NestedMeta::Meta(Meta::NameValue(MetaNameValue {
415                        path,
416                        lit: Lit::Str(lit),
417                        ..
418                    })) = nested
419                    {
420                        if path
421                            == parse_str::<Path>(attr_name)
422                                .unwrap_or_else(|_| panic!("{} should be valid path", attr_name))
423                        {
424                            output.push(Some(parse_str::<T>(&lit.value()).unwrap_or_else(|_| {
425                                panic!(
426                                    "{} should be valid tokens of type {}",
427                                    &lit.value(),
428                                    type_name::<T>()
429                                )
430                            })));
431                        }
432                    } else if let NestedMeta::Meta(Meta::Path(path)) = nested {
433                        if path
434                            == parse_str::<Path>(attr_name)
435                                .unwrap_or_else(|_| panic!("{} should be valid path", attr_name))
436                        {
437                            output.push(None);
438                        }
439                    }
440                }
441            }
442        }
443    }
444    output
445}
446
447pub fn process_trait_bounds(attrs: &[Attribute], trait_bound_path: &str) -> Vec<TypeParam> {
448    process_attr(attrs, trait_bound_path)
449        .into_iter()
450        .flatten()
451        .collect()
452}
453
454/// Handles the attribute `#[neli(padding)]`.
455pub fn process_padding(attrs: &[Attribute]) -> bool {
456    attr_present(attrs, "padding")
457}
458
459/// Handles the attribute `#[neli(input)]` or `#[neli(input = "...")]`
460/// when deriving [`FromBytes`][neli::FromBytes] implementations.
461///
462/// Returns:
463/// * [`None`] if the attribute is not present
464/// * [`Some(None)`] if the attribute is present and has no
465/// associated expression
466/// * [`Some(Some(_))`] if the attribute is present and
467/// has an associated expression
468pub fn process_input(attrs: &[Attribute]) -> Option<Option<Expr>> {
469    let mut exprs = process_attr(attrs, "input");
470    if exprs.len() > 1 {
471        panic!("Only one input expression allowed for attribute #[neli(input = \"...\")]");
472    } else {
473        exprs.pop()
474    }
475}
476
477/// Handles the attribute `#[neli(size = "...")]`
478/// when deriving [`FromBytes`][neli::FromBytes] implementations.
479///
480/// Returns:
481/// * [`None`] if the attribute is not present
482/// associated expression
483/// * [`Some(_)`] if the attribute is present and has an associated expression
484pub fn process_size(attrs: &[Attribute]) -> Option<Expr> {
485    let mut exprs = process_attr(attrs, "size");
486    if exprs.len() > 1 {
487        panic!("Only one input expression allowed for attribute #[neli(size = \"...\")]");
488    } else {
489        exprs
490            .pop()
491            .map(|opt| opt.expect("#[neli(size = \"...\")] must have associated expression"))
492    }
493}
494
495/// If the first type parameter of a list of type parameters is a lifetime,
496/// extract it for use in other parts of the procedural macro code.
497///
498/// # Example
499/// `impl<'a, I, P>` would return `'a`.
500pub fn process_lifetime(generics: &mut Generics) -> LifetimeDef {
501    if let Some(GenericParam::Lifetime(lt)) = generics.params.first() {
502        lt.clone()
503    } else {
504        let mut punc = Punctuated::new();
505        let lt = parse::<LifetimeDef>(TokenStream::from(quote! {
506            'lifetime
507        }))
508        .expect("'lifetime should be valid lifetime");
509        punc.push(GenericParam::Lifetime(lt.clone()));
510        punc.push_punct(Token![,](Span::call_site()));
511        punc.extend(generics.params.iter().cloned());
512        generics.params = punc;
513        lt
514    }
515}
516
517/// Allow overriding the trait bounds specified by the method
518/// [`process_impl_generics`][process_impl_generics].
519///
520/// # Example
521/// ```no_compile
522/// use std::marker::PhantomData;
523///
524/// struct MyStruct<I, A>(PhantomData<I>, PhantomData<A>);
525///
526/// trait MyTrait {}
527/// trait AnotherTrait {}
528///
529/// // Input
530///
531/// impl<I: MyTrait, A: MyTrait> MyStruct<I, A> {
532///     fn nothing() {}
533/// }
534///
535/// // Result
536///
537/// impl<I: AnotherTrait, A: MyTrait> MyStruct<I, A> {
538///     fn nothing() {}
539/// }
540/// ```
541fn override_trait_bounds_on_generics(generics: &mut Generics, trait_bound_overrides: &[TypeParam]) {
542    let mut overrides = trait_bound_overrides.iter().cloned().fold(
543        HashMap::<Ident, Punctuated<TypeParamBound, Add>>::new(),
544        |mut map, param| {
545            if let Some(bounds) = map.get_mut(&param.ident) {
546                bounds.extend(param.bounds);
547            } else {
548                map.insert(param.ident, param.bounds);
549            }
550            map
551        },
552    );
553
554    for generic in generics.params.iter_mut() {
555        if let GenericParam::Type(ref mut ty) = generic {
556            let ident = &ty.ident;
557            if let Some(ors) = overrides.remove(ident) {
558                ty.colon_token = Some(Token![:](Span::call_site()));
559                ty.bounds = ors;
560                ty.eq_token = None;
561                ty.default = None;
562            }
563        }
564    }
565}