neli_proc_macros/
shared.rs

1use std::{any::type_name, collections::HashMap};
2
3use proc_macro2::{Span, TokenStream as TokenStream2, TokenTree};
4use quote::{ToTokens, quote};
5use syn::{
6    Attribute, Expr, Fields, FieldsNamed, FieldsUnnamed, GenericParam, Generics, Ident, Index,
7    ItemStruct, LitStr, Meta, MetaNameValue, Path, PathArguments, PathSegment, Token, TraitBound,
8    TraitBoundModifier, Type, TypeParam, TypeParamBound, Variant,
9    parse::Parse,
10    parse_str,
11    punctuated::Punctuated,
12    token::{PathSep, Plus},
13};
14
15/// Represents a field as either an identifier or an index.
16pub enum FieldRepr {
17    Index(Index),
18    Ident(Ident),
19}
20
21impl ToTokens for FieldRepr {
22    fn to_tokens(&self, tokens: &mut TokenStream2) {
23        match self {
24            FieldRepr::Index(i) => i.to_tokens(tokens),
25            FieldRepr::Ident(i) => i.to_tokens(tokens),
26        }
27    }
28}
29
30/// Represents the field name, type, and all attributes associated
31/// with this field.
32pub struct FieldInfo {
33    field_name: FieldRepr,
34    field_type: Type,
35    field_attrs: Vec<Attribute>,
36}
37
38impl FieldInfo {
39    /// Convert field info to a tuple.
40    fn into_tuple(self) -> (FieldRepr, Type, Vec<Attribute>) {
41        (self.field_name, self.field_type, self.field_attrs)
42    }
43
44    /// Convert a vector of [`FieldInfo`]s to a tuple of vectors
45    /// each containing name, type, or attributes.
46    pub fn to_vecs<I>(v: I) -> (Vec<FieldRepr>, Vec<Type>, Vec<Vec<Attribute>>)
47    where
48        I: Iterator<Item = Self>,
49    {
50        v.into_iter().fold(
51            (Vec::new(), Vec::new(), Vec::new()),
52            |(mut names, mut types, mut attrs), info| {
53                let (name, ty, attr) = info.into_tuple();
54                names.push(name);
55                types.push(ty);
56                attrs.push(attr);
57                (names, types, attrs)
58            },
59        )
60    }
61}
62
63/// Necessary information for a given struct to generate trait
64/// implementations.
65pub struct StructInfo {
66    struct_name: Ident,
67    generics: Generics,
68    generics_without_bounds: Generics,
69    field_info: Vec<FieldInfo>,
70    padded: bool,
71}
72
73type StructInfoTuple = (
74    Ident,
75    Generics,
76    Generics,
77    Vec<FieldRepr>,
78    Vec<Type>,
79    Vec<Vec<Attribute>>,
80    bool,
81);
82
83impl StructInfo {
84    /// Extract the necessary information from an
85    /// [`ItemStruct`][syn::ItemStruct] data structure.
86    pub fn from_item_struct(
87        i: ItemStruct,
88        trait_name: Option<&str>,
89        trait_bound_path: &str,
90        uses_self: bool,
91    ) -> Self {
92        let (mut generics, generics_without_bounds) = process_impl_generics(i.generics, trait_name);
93        let trait_bounds = process_trait_bounds(&i.attrs, trait_bound_path);
94        override_trait_bounds_on_generics(&mut generics, &trait_bounds);
95        let field_info = match i.fields {
96            Fields::Named(fields_named) => generate_named_fields(fields_named),
97            Fields::Unnamed(fields_unnamed) => generate_unnamed_fields(fields_unnamed, uses_self),
98            Fields::Unit => Vec::new(),
99        };
100        let padded = process_padding(&i.attrs);
101
102        StructInfo {
103            struct_name: i.ident,
104            generics,
105            generics_without_bounds,
106            field_info,
107            padded,
108        }
109    }
110
111    /// Remove the last field from the record.
112    pub fn pop_field(&mut self) {
113        let _ = self.field_info.pop();
114    }
115
116    /// Convert all necessary struct information into a tuple of
117    /// values.
118    pub fn into_tuple(mut self) -> StructInfoTuple {
119        let (field_names, field_types, field_attrs) = self.field_info();
120        (
121            self.struct_name,
122            self.generics,
123            self.generics_without_bounds,
124            field_names,
125            field_types,
126            field_attrs,
127            self.padded,
128        )
129    }
130
131    /// Convert all field information into a tuple.
132    fn field_info(&mut self) -> (Vec<FieldRepr>, Vec<Type>, Vec<Vec<Attribute>>) {
133        FieldInfo::to_vecs(self.field_info.drain(..))
134    }
135}
136
137/// Convert a list of identifiers into a path where the path segments
138/// are added in the order that they appear in the list.
139fn path_from_idents(idents: &[&str]) -> Path {
140    Path {
141        leading_colon: None,
142        segments: idents
143            .iter()
144            .map(|ident| PathSegment {
145                ident: Ident::new(ident, Span::call_site()),
146                arguments: PathArguments::None,
147            })
148            .collect::<Punctuated<PathSegment, PathSep>>(),
149    }
150}
151
152/// Process all type parameters in the type parameter definition for
153/// an `impl` block. Optionally add a trait bound for all type parameters
154/// if `required_trait` is `Some(_)`.
155///
156/// The first return value in the tuple is the list of type parameters
157/// with trait bounds added. The second argument is a list of type
158/// parameters without trait bounds to be passed into the type parameter
159/// list for a struct.
160///
161/// # Example:
162/// ## impl block
163///
164/// ```no_compile
165/// trait MyTrait {}
166///
167/// impl<T, P> MyStruct<T, P> {
168///     fn nothing() {}
169/// }
170/// ```
171///
172/// ## Method call
173/// `neli_proc_macros::process_impl_generics(generics, Some("MyTrait"))`
174///
175/// ## Result
176/// ```no_compile
177/// (<T: MyTrait, P: MyTrait>, <T, P>)
178/// ```
179///
180/// or rather:
181///
182/// ```no_compile
183/// impl<T: MyTrait, P: MyTrait> MyStruct<T, P> {
184///     fn nothing() {}
185/// }
186/// ```
187pub fn process_impl_generics(
188    mut generics: Generics,
189    required_trait: Option<&str>,
190) -> (Generics, Generics) {
191    if let Some(rt) = required_trait {
192        for generic in generics.params.iter_mut() {
193            if let GenericParam::Type(param) = generic {
194                param.colon_token = Some(Token![:](Span::call_site()));
195                param.bounds.push(TypeParamBound::Trait(TraitBound {
196                    paren_token: None,
197                    modifier: TraitBoundModifier::None,
198                    lifetimes: None,
199                    path: path_from_idents(&["neli", rt]),
200                }));
201                param.eq_token = None;
202                param.default = None;
203            }
204        }
205    }
206
207    let mut generics_without_bounds: Generics = generics.clone();
208    for generic in generics_without_bounds.params.iter_mut() {
209        if let GenericParam::Type(param) = generic {
210            param.colon_token = None;
211            param.bounds.clear();
212            param.eq_token = None;
213            param.default = None;
214        }
215    }
216
217    (generics, generics_without_bounds)
218}
219
220/// Remove attributes that should not be carried over to an `impl`
221/// definition and only belong in the data structure like documentation
222/// attributes.
223pub fn remove_bad_attrs(attrs: Vec<Attribute>) -> Vec<Attribute> {
224    attrs
225        .into_iter()
226        .filter(|attr| match &attr.meta {
227            Meta::NameValue(MetaNameValue { path, .. }) => !path.is_ident("doc"),
228            _ => true,
229        })
230        .collect()
231}
232
233/// Generate a pattern and associated expression for each variant
234/// in an enum.
235fn generate_pat_and_expr<N, U>(
236    enum_name: Ident,
237    var_name: Ident,
238    fields: Fields,
239    generate_named_pat_and_expr: &N,
240    generate_unnamed_pat_and_expr: &U,
241    unit: &TokenStream2,
242) -> TokenStream2
243where
244    N: Fn(Ident, Ident, FieldsNamed) -> TokenStream2,
245    U: Fn(Ident, Ident, FieldsUnnamed) -> TokenStream2,
246{
247    match fields {
248        Fields::Named(fields) => generate_named_pat_and_expr(enum_name, var_name, fields),
249        Fields::Unnamed(fields) => generate_unnamed_pat_and_expr(enum_name, var_name, fields),
250        Fields::Unit => quote! {
251            #enum_name::#var_name => #unit,
252        },
253    }
254}
255
256/// Convert an enum variant into an arm of a match statement.
257fn generate_arm<N, U>(
258    attrs: Vec<Attribute>,
259    enum_name: Ident,
260    var_name: Ident,
261    fields: Fields,
262    generate_named_pat_and_expr: &N,
263    generate_unnamed_pat_and_expr: &U,
264    unit: &TokenStream2,
265) -> TokenStream2
266where
267    N: Fn(Ident, Ident, FieldsNamed) -> TokenStream2,
268    U: Fn(Ident, Ident, FieldsUnnamed) -> TokenStream2,
269{
270    let attrs = remove_bad_attrs(attrs)
271        .into_iter()
272        .map(|attr| attr.meta)
273        .collect::<Vec<_>>();
274    let arm = generate_pat_and_expr(
275        enum_name,
276        var_name,
277        fields,
278        generate_named_pat_and_expr,
279        generate_unnamed_pat_and_expr,
280        unit,
281    );
282    quote! {
283        #(
284            #attrs
285        )*
286        #arm
287    }
288}
289
290/// Generate all arms of a match statement.
291pub fn generate_arms<N, U>(
292    enum_name: Ident,
293    variants: Vec<Variant>,
294    generate_named_pat_and_expr: N,
295    generate_unnamed_pat_and_expr: U,
296    unit: TokenStream2,
297) -> Vec<TokenStream2>
298where
299    N: Fn(Ident, Ident, FieldsNamed) -> TokenStream2,
300    U: Fn(Ident, Ident, FieldsUnnamed) -> TokenStream2,
301{
302    variants
303        .into_iter()
304        .map(|var| {
305            let variant_name = var.ident;
306            generate_arm(
307                var.attrs,
308                enum_name.clone(),
309                variant_name,
310                var.fields,
311                &generate_named_pat_and_expr,
312                &generate_unnamed_pat_and_expr,
313                &unit,
314            )
315        })
316        .collect()
317}
318
319/// Generate a list of named fields in accordance with the struct.
320pub fn generate_named_fields(fields: FieldsNamed) -> Vec<FieldInfo> {
321    fields
322        .named
323        .into_iter()
324        .fold(Vec::new(), |mut info, field| {
325            info.push(FieldInfo {
326                field_name: FieldRepr::Ident(field.ident.expect("Must be named")),
327                field_type: field.ty,
328                field_attrs: field.attrs,
329            });
330            info
331        })
332}
333
334/// Generate unnamed fields as either indicies to be accessed using
335/// `self` or placeholder variable names for match-style patterns.
336pub fn generate_unnamed_fields(fields: FieldsUnnamed, uses_self: bool) -> Vec<FieldInfo> {
337    fields
338        .unnamed
339        .into_iter()
340        .enumerate()
341        .fold(Vec::new(), |mut fields, (index, field)| {
342            fields.push(FieldInfo {
343                field_name: if uses_self {
344                    FieldRepr::Index(Index {
345                        index: index as u32,
346                        span: Span::call_site(),
347                    })
348                } else {
349                    FieldRepr::Ident(Ident::new(
350                        &String::from((b'a' + index as u8) as char),
351                        Span::call_site(),
352                    ))
353                },
354                field_type: field.ty,
355                field_attrs: field.attrs,
356            });
357            fields
358        })
359}
360
361/// Returns [`true`] if the given attribute is present in the list.
362fn attr_present(attrs: &[Attribute], attr_name: &str) -> bool {
363    for attr in attrs {
364        if let Meta::List(list) = &attr.meta {
365            if list.path.is_ident("neli") {
366                for token in list.tokens.clone() {
367                    if let TokenTree::Ident(ident) = token {
368                        if ident == attr_name {
369                            return true;
370                        }
371                    }
372                }
373            }
374        }
375    }
376    false
377}
378
379/// Process attributes to find all attributes with the name `attr_name`.
380/// Return a [`Vec`] of [`Option`] types with the associated literal parsed
381/// into type parameter `T`. `T` must allow parsing from a string to be
382/// used with this method.
383fn process_attr<T>(attrs: &[Attribute], attr_name: &str) -> Vec<Option<T>>
384where
385    T: Parse,
386{
387    let mut output = Vec::new();
388    for attr in attrs {
389        if attr.path().is_ident("neli") {
390            attr.parse_nested_meta(|meta| {
391                let literal_str = match meta.value() {
392                    Ok(value) => match value.parse::<LitStr>() {
393                        Ok(v) => Some(v.value()),
394                        Err(_) => panic!("Cannot have a bare ="),
395                    },
396                    Err(_) => None,
397                };
398                if meta.path.is_ident(attr_name) {
399                    match literal_str {
400                        Some(l) => {
401                            output.push(Some(parse_str::<T>(&l).unwrap_or_else(|_| {
402                                panic!("{} should be valid tokens of type {}", l, type_name::<T>())
403                            })));
404                        }
405                        None => {
406                            output.push(None);
407                        }
408                    }
409                }
410                Ok(())
411            })
412            .unwrap_or_else(|e| {
413                panic!(
414                    "{}",
415                    format!("Should be able to parse all nested attributes: {e}")
416                )
417            });
418        }
419    }
420    output
421}
422
423pub fn process_trait_bounds(attrs: &[Attribute], trait_bound_path: &str) -> Vec<TypeParam> {
424    process_attr(attrs, trait_bound_path)
425        .into_iter()
426        .flatten()
427        .collect()
428}
429
430/// Handles the attribute `#[neli(padding)]`.
431pub fn process_padding(attrs: &[Attribute]) -> bool {
432    attr_present(attrs, "padding")
433}
434
435/// Handles the attribute `#[neli(input)]` or `#[neli(input = "...")]`
436/// when deriving [`FromBytes`][neli::FromBytes] implementations.
437///
438/// Returns:
439/// * [`None`] if the attribute is not present
440/// * [`Some(None)`] if the attribute is present and has no
441///   associated expression
442/// * [`Some(Some(_))`] if the attribute is present and
443///   has an associated expression
444pub fn process_input(attrs: &[Attribute]) -> Option<Option<Expr>> {
445    let mut exprs = process_attr(attrs, "input");
446    if exprs.len() > 1 {
447        panic!("Only one instance of the attribute allowed for attribute #[neli(input = \"...\")]");
448    } else {
449        exprs.pop()
450    }
451}
452
453/// Handles the attribute `#[neli(skip_debug)]`
454/// when deriving [`FromBytes`][neli::FromBytes] implementations.
455/// This removes the restriction for the field to have [`TypeSize`][neli::TypeSize]
456/// implemented by skipping buffer trace logging for this field.
457///
458/// Returns:
459/// * [`false`] if the attribute is not present
460/// * [`true`] if the attribute is present
461pub fn process_skip_debug(attrs: &[Attribute]) -> bool {
462    let exprs = process_attr::<Expr>(attrs, "skip_debug");
463    if exprs.is_empty() {
464        false
465    } else if exprs.iter().any(|expr| expr.is_some()) {
466        panic!("No input expressions allowed for #[neli(skip_debug)]")
467    } else {
468        true
469    }
470}
471
472/// Handles the attribute `#[neli(size = "...")]`
473/// when deriving [`FromBytes`][neli::FromBytes] implementations.
474///
475/// Returns:
476/// * [`None`] if the attribute is not present
477///   associated expression
478/// * [`Some(_)`] if the attribute is present and has an associated expression
479pub fn process_size(attrs: &[Attribute]) -> Option<Expr> {
480    let mut exprs = process_attr(attrs, "size");
481    if exprs.len() > 1 {
482        panic!("Only one input expression allowed for attribute #[neli(size = \"...\")]");
483    } else {
484        exprs
485            .pop()
486            .map(|opt| opt.expect("#[neli(size = \"...\")] must have associated expression"))
487    }
488}
489
490/// Allow overriding the trait bounds specified by the method
491/// [`process_impl_generics`][process_impl_generics].
492///
493/// # Example
494/// ```no_compile
495/// use std::marker::PhantomData;
496///
497/// struct MyStruct<I, A>(PhantomData<I>, PhantomData<A>);
498///
499/// trait MyTrait {}
500/// trait AnotherTrait {}
501///
502/// // Input
503///
504/// impl<I: MyTrait, A: MyTrait> MyStruct<I, A> {
505///     fn nothing() {}
506/// }
507///
508/// // Result
509///
510/// impl<I: AnotherTrait, A: MyTrait> MyStruct<I, A> {
511///     fn nothing() {}
512/// }
513/// ```
514fn override_trait_bounds_on_generics(generics: &mut Generics, trait_bound_overrides: &[TypeParam]) {
515    let mut overrides = trait_bound_overrides.iter().cloned().fold(
516        HashMap::<Ident, Punctuated<TypeParamBound, Plus>>::new(),
517        |mut map, param| {
518            if let Some(bounds) = map.get_mut(&param.ident) {
519                bounds.extend(param.bounds);
520            } else {
521                map.insert(param.ident, param.bounds);
522            }
523            map
524        },
525    );
526
527    for generic in generics.params.iter_mut() {
528        if let GenericParam::Type(ty) = generic {
529            let ident = &ty.ident;
530            if let Some(ors) = overrides.remove(ident) {
531                ty.colon_token = Some(Token![:](Span::call_site()));
532                ty.bounds = ors;
533                ty.eq_token = None;
534                ty.default = None;
535            }
536        }
537    }
538}