merge_derive/
lib.rs

1// SPDX-FileCopyrightText: 2020 Robin Krahl <robin.krahl@ireas.org>
2// SPDX-License-Identifier: Apache-2.0 or MIT
3
4//! A derive macro for the [`merge::Merge`][] trait.
5//!
6//! See the documentation for the [`merge`][] crate for more information.
7//!
8//! [`merge`]: https://lib.rs/crates/merge
9//! [`merge::Merge`]: https://docs.rs/merge/latest/merge/trait.Merge.html
10
11extern crate proc_macro;
12
13use proc_macro2::TokenStream;
14use proc_macro_error::{abort, abort_call_site, dummy::set_dummy, proc_macro_error, ResultExt};
15use quote::{quote, quote_spanned};
16use syn::Token;
17
18struct Field {
19    name: syn::Member,
20    span: proc_macro2::Span,
21    attrs: FieldAttrs,
22}
23
24#[derive(Default)]
25struct FieldAttrs {
26    skip: bool,
27    strategy: Option<syn::Path>,
28}
29
30enum FieldAttr {
31    Skip,
32    Strategy(syn::Path),
33}
34
35#[proc_macro_derive(Merge, attributes(merge))]
36#[proc_macro_error]
37pub fn merge_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
38    let ast = syn::parse(input).unwrap();
39    impl_merge(&ast).into()
40}
41
42fn impl_merge(ast: &syn::DeriveInput) -> TokenStream {
43    let name = &ast.ident;
44
45    set_dummy(quote! {
46        impl ::merge::Merge for #name {
47            fn merge(&mut self, other: Self) {
48                unimplemented!()
49            }
50        }
51    });
52
53    if let syn::Data::Struct(syn::DataStruct { ref fields, .. }) = ast.data {
54        impl_merge_for_struct(name, fields)
55    } else {
56        abort_call_site!("merge::Merge can only be derived for structs")
57    }
58}
59
60fn impl_merge_for_struct(name: &syn::Ident, fields: &syn::Fields) -> TokenStream {
61    let assignments = gen_assignments(fields);
62
63    quote! {
64        impl ::merge::Merge for #name {
65            fn merge(&mut self, other: Self) {
66                #assignments
67            }
68        }
69    }
70}
71
72fn gen_assignments(fields: &syn::Fields) -> TokenStream {
73    let fields = fields.iter().enumerate().map(Field::from);
74    let assignments = fields.filter(|f| !f.attrs.skip).map(|f| gen_assignment(&f));
75    quote! {
76        #( #assignments )*
77    }
78}
79
80fn gen_assignment(field: &Field) -> TokenStream {
81    use syn::spanned::Spanned;
82
83    let name = &field.name;
84    if let Some(strategy) = &field.attrs.strategy {
85        quote_spanned!(strategy.span()=> #strategy(&mut self.#name, other.#name);)
86    } else {
87        quote_spanned!(field.span=> ::merge::Merge::merge(&mut self.#name, other.#name);)
88    }
89}
90
91impl From<(usize, &syn::Field)> for Field {
92    fn from(data: (usize, &syn::Field)) -> Self {
93        use syn::spanned::Spanned;
94
95        let (index, field) = data;
96        Field {
97            name: if let Some(ident) = &field.ident {
98                syn::Member::Named(ident.clone())
99            } else {
100                syn::Member::Unnamed(index.into())
101            },
102            span: field.span(),
103            attrs: field.attrs.iter().into(),
104        }
105    }
106}
107
108impl FieldAttrs {
109    fn apply(&mut self, attr: FieldAttr) {
110        match attr {
111            FieldAttr::Skip => self.skip = true,
112            FieldAttr::Strategy(path) => self.strategy = Some(path),
113        }
114    }
115}
116
117impl<'a, I: Iterator<Item = &'a syn::Attribute>> From<I> for FieldAttrs {
118    fn from(iter: I) -> Self {
119        let mut field_attrs = Self::default();
120
121        for attr in iter {
122            if !attr.path.is_ident("merge") {
123                continue;
124            }
125
126            let parser = syn::punctuated::Punctuated::<FieldAttr, Token![,]>::parse_terminated;
127            for attr in attr.parse_args_with(parser).unwrap_or_abort() {
128                field_attrs.apply(attr);
129            }
130        }
131
132        field_attrs
133    }
134}
135
136impl syn::parse::Parse for FieldAttr {
137    fn parse(input: syn::parse::ParseStream) -> syn::parse::Result<Self> {
138        let name: syn::Ident = input.parse()?;
139        if name == "skip" {
140            // TODO check remaining stream
141            Ok(FieldAttr::Skip)
142        } else if name == "strategy" {
143            let _: Token![=] = input.parse()?;
144            let path: syn::Path = input.parse()?;
145            Ok(FieldAttr::Strategy(path))
146        } else {
147            abort!(name, "Unexpected attribute: {}", name)
148        }
149    }
150}