merge_derive/
lib.rs

1// SPDX-FileCopyrightText: 2020 Robin Krahl <robin.krahl@ireas.org>
2// SPDX-License-Identifier: Apache-2.0 or MIT
3
4//! A derive macro for the [`merge::Merge`][] trait.
5//!
6//! See the documentation for the [`merge`][] crate for more information.
7//!
8//! [`merge`]: https://lib.rs/crates/merge
9//! [`merge::Merge`]: https://docs.rs/merge/latest/merge/trait.Merge.html
10
11extern crate proc_macro;
12
13use proc_macro2::TokenStream;
14use proc_macro_error2::{abort, abort_call_site, dummy::set_dummy, proc_macro_error, ResultExt};
15use quote::{quote, quote_spanned};
16use syn::Token;
17
18struct Field {
19    name: syn::Member,
20    span: proc_macro2::Span,
21    attrs: FieldAttrs,
22}
23
24#[derive(Default)]
25struct FieldAttrs {
26    skip: bool,
27    strategy: Option<syn::Path>,
28}
29
30enum FieldAttr {
31    Skip,
32    Strategy(syn::Path),
33}
34
35#[proc_macro_derive(Merge, attributes(merge))]
36#[proc_macro_error]
37pub fn merge_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
38    let ast = syn::parse(input).unwrap();
39    impl_merge(&ast).into()
40}
41
42fn impl_merge(ast: &syn::DeriveInput) -> TokenStream {
43    let name = &ast.ident;
44    let default_strategy = FieldAttrs::from(ast.attrs.iter());
45
46    set_dummy(quote! {
47        impl ::merge::Merge for #name {
48            fn merge(&mut self, other: Self) {
49                unimplemented!()
50            }
51        }
52    });
53
54    if let syn::Data::Struct(syn::DataStruct { ref fields, .. }) = ast.data {
55        impl_merge_for_struct(name, fields, default_strategy)
56    } else {
57        abort_call_site!("merge::Merge can only be derived for structs")
58    }
59}
60
61fn impl_merge_for_struct(
62    name: &syn::Ident,
63    fields: &syn::Fields,
64    default_strategy: FieldAttrs,
65) -> TokenStream {
66    let assignments = gen_assignments(fields, default_strategy);
67
68    quote! {
69        impl ::merge::Merge for #name {
70            fn merge(&mut self, other: Self) {
71                #assignments
72            }
73        }
74    }
75}
76
77fn gen_assignments(fields: &syn::Fields, default_strategy: FieldAttrs) -> TokenStream {
78    let fields = fields.iter().enumerate().map(Field::from);
79    let assignments = fields
80        .filter(|f| !f.attrs.skip)
81        .map(|f| gen_assignment(&f, &default_strategy));
82    quote! {
83        #( #assignments )*
84    }
85}
86
87fn gen_assignment(field: &Field, default_strategy: &FieldAttrs) -> TokenStream {
88    use syn::spanned::Spanned;
89
90    let name = &field.name;
91    if let Some(strategy) = &field.attrs.strategy {
92        quote_spanned!(strategy.span()=> #strategy(&mut self.#name, other.#name);)
93    } else if let Some(default) = &default_strategy.strategy {
94        quote_spanned!(default.span()=> #default(&mut self.#name, other.#name);)
95    } else {
96        quote_spanned!(field.span=> ::merge::Merge::merge(&mut self.#name, other.#name);)
97    }
98}
99
100impl From<(usize, &syn::Field)> for Field {
101    fn from(data: (usize, &syn::Field)) -> Self {
102        use syn::spanned::Spanned;
103
104        let (index, field) = data;
105        Field {
106            name: if let Some(ident) = &field.ident {
107                syn::Member::Named(ident.clone())
108            } else {
109                syn::Member::Unnamed(index.into())
110            },
111            span: field.span(),
112            attrs: field.attrs.iter().into(),
113        }
114    }
115}
116
117impl FieldAttrs {
118    fn apply(&mut self, attr: FieldAttr) {
119        match attr {
120            FieldAttr::Skip => self.skip = true,
121            FieldAttr::Strategy(path) => self.strategy = Some(path),
122        }
123    }
124}
125
126impl<'a, I: Iterator<Item = &'a syn::Attribute>> From<I> for FieldAttrs {
127    fn from(iter: I) -> Self {
128        let mut field_attrs = Self::default();
129
130        for attr in iter {
131            if !attr.path().is_ident("merge") {
132                continue;
133            }
134
135            let parser = syn::punctuated::Punctuated::<FieldAttr, Token![,]>::parse_terminated;
136            for attr in attr.parse_args_with(parser).unwrap_or_abort() {
137                field_attrs.apply(attr);
138            }
139        }
140
141        field_attrs
142    }
143}
144
145impl syn::parse::Parse for FieldAttr {
146    fn parse(input: syn::parse::ParseStream) -> syn::parse::Result<Self> {
147        let name: syn::Ident = input.parse()?;
148        if name == "skip" {
149            // TODO check remaining stream
150            Ok(FieldAttr::Skip)
151        } else if name == "strategy" {
152            let _: Token![=] = input.parse()?;
153            let path: syn::Path = input.parse()?;
154            Ok(FieldAttr::Strategy(path))
155        } else {
156            abort!(name, "Unexpected attribute: {}", name)
157        }
158    }
159}