merge_derive/
lib.rs
1extern crate proc_macro;
12
13use proc_macro2::TokenStream;
14use proc_macro_error::{abort, abort_call_site, dummy::set_dummy, proc_macro_error, ResultExt};
15use quote::{quote, quote_spanned};
16use syn::Token;
17
18struct Field {
19 name: syn::Member,
20 span: proc_macro2::Span,
21 attrs: FieldAttrs,
22}
23
24#[derive(Default)]
25struct FieldAttrs {
26 skip: bool,
27 strategy: Option<syn::Path>,
28}
29
30enum FieldAttr {
31 Skip,
32 Strategy(syn::Path),
33}
34
35#[proc_macro_derive(Merge, attributes(merge))]
36#[proc_macro_error]
37pub fn merge_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
38 let ast = syn::parse(input).unwrap();
39 impl_merge(&ast).into()
40}
41
42fn impl_merge(ast: &syn::DeriveInput) -> TokenStream {
43 let name = &ast.ident;
44
45 set_dummy(quote! {
46 impl ::merge::Merge for #name {
47 fn merge(&mut self, other: Self) {
48 unimplemented!()
49 }
50 }
51 });
52
53 if let syn::Data::Struct(syn::DataStruct { ref fields, .. }) = ast.data {
54 impl_merge_for_struct(name, fields)
55 } else {
56 abort_call_site!("merge::Merge can only be derived for structs")
57 }
58}
59
60fn impl_merge_for_struct(name: &syn::Ident, fields: &syn::Fields) -> TokenStream {
61 let assignments = gen_assignments(fields);
62
63 quote! {
64 impl ::merge::Merge for #name {
65 fn merge(&mut self, other: Self) {
66 #assignments
67 }
68 }
69 }
70}
71
72fn gen_assignments(fields: &syn::Fields) -> TokenStream {
73 let fields = fields.iter().enumerate().map(Field::from);
74 let assignments = fields.filter(|f| !f.attrs.skip).map(|f| gen_assignment(&f));
75 quote! {
76 #( #assignments )*
77 }
78}
79
80fn gen_assignment(field: &Field) -> TokenStream {
81 use syn::spanned::Spanned;
82
83 let name = &field.name;
84 if let Some(strategy) = &field.attrs.strategy {
85 quote_spanned!(strategy.span()=> #strategy(&mut self.#name, other.#name);)
86 } else {
87 quote_spanned!(field.span=> ::merge::Merge::merge(&mut self.#name, other.#name);)
88 }
89}
90
91impl From<(usize, &syn::Field)> for Field {
92 fn from(data: (usize, &syn::Field)) -> Self {
93 use syn::spanned::Spanned;
94
95 let (index, field) = data;
96 Field {
97 name: if let Some(ident) = &field.ident {
98 syn::Member::Named(ident.clone())
99 } else {
100 syn::Member::Unnamed(index.into())
101 },
102 span: field.span(),
103 attrs: field.attrs.iter().into(),
104 }
105 }
106}
107
108impl FieldAttrs {
109 fn apply(&mut self, attr: FieldAttr) {
110 match attr {
111 FieldAttr::Skip => self.skip = true,
112 FieldAttr::Strategy(path) => self.strategy = Some(path),
113 }
114 }
115}
116
117impl<'a, I: Iterator<Item = &'a syn::Attribute>> From<I> for FieldAttrs {
118 fn from(iter: I) -> Self {
119 let mut field_attrs = Self::default();
120
121 for attr in iter {
122 if !attr.path.is_ident("merge") {
123 continue;
124 }
125
126 let parser = syn::punctuated::Punctuated::<FieldAttr, Token![,]>::parse_terminated;
127 for attr in attr.parse_args_with(parser).unwrap_or_abort() {
128 field_attrs.apply(attr);
129 }
130 }
131
132 field_attrs
133 }
134}
135
136impl syn::parse::Parse for FieldAttr {
137 fn parse(input: syn::parse::ParseStream) -> syn::parse::Result<Self> {
138 let name: syn::Ident = input.parse()?;
139 if name == "skip" {
140 Ok(FieldAttr::Skip)
142 } else if name == "strategy" {
143 let _: Token![=] = input.parse()?;
144 let path: syn::Path = input.parse()?;
145 Ok(FieldAttr::Strategy(path))
146 } else {
147 abort!(name, "Unexpected attribute: {}", name)
148 }
149 }
150}