enum_dispatch/
expansion.rs

1//! Provides a utility for generating `enum_dispatch` impl blocks given `EnumDispatchItem` and
2//! `syn::ItemTrait` definitions.
3use crate::cache;
4use quote::{quote, ToTokens};
5use syn::spanned::Spanned;
6
7use crate::enum_dispatch_item::EnumDispatchItem;
8use crate::enum_dispatch_variant::EnumDispatchVariant;
9use crate::syn_utils::plain_identifier_expr;
10
/// Identifier that every generated match arm binds a variant's single field to, as in
/// `MyEnum::Variant(inner) => ...`. Any name would do as long as the pattern and the arm body
/// agree, so one fixed name is used everywhere such a binding is generated.
const FIELDNAME: &str = "inner";
15
16/// Implements the specified trait for the given enum definition, assuming the trait definition is
17/// already present in local storage.
18pub fn add_enum_impls(
19    enum_def: EnumDispatchItem,
20    traitdef: syn::ItemTrait,
21) -> proc_macro2::TokenStream {
22    let traitname = traitdef.ident;
23    let traitfns = traitdef.items;
24
25    let (generic_impl_constraints, enum_type_generics, where_clause) =
26        enum_def.generics.split_for_impl();
27    let (_, trait_type_generics, _) = traitdef.generics.split_for_impl();
28
29    let enumname = &enum_def.ident.to_owned();
30    let trait_impl = quote! {
31        impl #generic_impl_constraints #traitname #trait_type_generics for #enumname #enum_type_generics #where_clause {
32
33        }
34    };
35    let mut trait_impl: syn::ItemImpl = syn::parse(trait_impl.into()).unwrap();
36
37    trait_impl.unsafety = traitdef.unsafety;
38
39    let variants: Vec<&EnumDispatchVariant> = enum_def.variants.iter().collect();
40
41    for trait_fn in traitfns {
42        trait_impl.items.push(create_trait_match(
43            trait_fn,
44            &trait_type_generics,
45            &traitname,
46            &enum_def.ident,
47            &variants,
48        ));
49    }
50
51    let mut impls = proc_macro2::TokenStream::new();
52
53    // Only generate From impls once per enum_def
54    if !cache::conversion_impls_def_by_enum(
55        &enum_def.ident,
56        enum_def.generics.type_params().count(),
57    ) {
58        let from_impls = generate_from_impls(&enum_def.ident, &variants, &enum_def.generics);
59        for from_impl in from_impls.iter() {
60            from_impl.to_tokens(&mut impls);
61        }
62
63        let try_into_impls =
64            generate_try_into_impls(&enum_def.ident, &variants, &trait_impl.generics);
65        for try_into_impl in try_into_impls.iter() {
66            try_into_impl.to_tokens(&mut impls);
67        }
68        cache::cache_enum_conversion_impls_defined(
69            enum_def.ident.clone(),
70            enum_def.generics.type_params().count(),
71        );
72    }
73
74    trait_impl.to_tokens(&mut impls);
75    impls
76}
77
78/// Returns whether or not an attribute from an enum variant should be applied to other usages of
79/// that variant's identifier.
80fn use_attribute(attr: &&syn::Attribute) -> bool {
81    attr.path().is_ident("cfg")
82}
83
84/// Generates impls of core::convert::From for each enum variant.
85fn generate_from_impls(
86    enumname: &syn::Ident,
87    enumvariants: &[&EnumDispatchVariant],
88    generics: &syn::Generics,
89) -> Vec<syn::ItemImpl> {
90    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
91    enumvariants
92        .iter()
93        .map(|variant| {
94            let variant_name = &variant.ident;
95            let variant_type = &variant.ty;
96            let attributes = &variant.attrs.iter().filter(use_attribute).collect::<Vec<_>>();
97            let impl_block = quote! {
98                #(#attributes)*
99                impl #impl_generics ::core::convert::From<#variant_type> for #enumname #ty_generics #where_clause {
100                    fn from(v: #variant_type) -> #enumname #ty_generics {
101                        #enumname::#variant_name(v)
102                    }
103                }
104            };
105            syn::parse(impl_block.into()).unwrap()
106        }).collect()
107}
108
109/// Generates impls of core::convert::TryInto for each enum variant.
110fn generate_try_into_impls(
111    enumname: &syn::Ident,
112    enumvariants: &[&EnumDispatchVariant],
113    generics: &syn::Generics,
114) -> Vec<syn::ItemImpl> {
115    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
116    enumvariants
117        .iter()
118        .enumerate()
119        .map(|(i, variant)| {
120            let variant_name = &variant.ident;
121            let variant_type = &variant.ty;
122            let attributes = &variant.attrs.iter().filter(use_attribute).collect::<Vec<_>>();
123
124            // Instead of making a specific match arm for each of the other variants we could just
125            // use a catch-all wildcard, but doing it this way means we get nicer error messages
126            // that say what the wrong variant is. It also degrades nicely in the case of a single
127            // variant enum so we don't get an unsightly "unreachable pattern" warning.
128            let other = enumvariants
129                .iter()
130                .enumerate()
131                .filter_map(
132                    |(j, other)| if i != j { Some(other) } else { None });
133            let other_attributes = other
134                .clone()
135                .map(|other| {
136                    let attrs = other.attrs.iter().filter(use_attribute);
137                    quote! { #(#attrs)* }
138                });
139            let other_idents = other
140                .map(|other| other.ident.clone());
141            let from_str = other_idents.clone().map(|ident| ident.to_string());
142            let to_str = core::iter::repeat(variant_name.to_string());
143            let repeated = core::iter::repeat(&enumname);
144
145            let impl_block = quote! {
146                #(#attributes)*
147                impl #impl_generics ::core::convert::TryInto<#variant_type> for #enumname #ty_generics #where_clause {
148                    type Error = &'static str;
149                    fn try_into(self) -> ::core::result::Result<#variant_type, <Self as ::core::convert::TryInto<#variant_type>>::Error> {
150                        match self {
151                            #enumname::#variant_name(v) => {Ok(v)},
152                            #(  #other_attributes
153                                #repeated::#other_idents(v) => {
154                                Err(concat!("Tried to convert variant ",
155                                            #from_str, " to ", #to_str))}    ),*
156                        }
157                    }
158                }
159            };
160            syn::parse(impl_block.into()).unwrap()
161        }).collect()
162}
163
/// How a trait method receives `self`, as read from its signature's receiver argument.
enum MethodType {
    /// The method has no receiver at all, so there is no value to dispatch on.
    Static,
    /// A receiver written with the `&` shorthand: `&self` or `&mut self`.
    ByReference,
    /// A receiver without the `&` shorthand, e.g. `self` or `mut self`.
    ByValue,
}
173
174/// Parses the arguments of a trait method's signature, returning all non-self arguments as well as
175/// a MethodType enum describing the self argument, if present.
176fn extract_fn_args(
177    trait_args: syn::punctuated::Punctuated<syn::FnArg, syn::token::Comma>,
178) -> (
179    MethodType,
180    syn::punctuated::Punctuated<syn::Expr, syn::token::Comma>,
181) {
182    let mut method_type = MethodType::Static;
183    let new_args: Vec<syn::Ident> = trait_args
184        .iter()
185        .filter_map(|arg| match arg {
186            syn::FnArg::Receiver(syn::Receiver {
187                reference: Some(_), ..
188            }) => {
189                method_type = MethodType::ByReference;
190                None
191            }
192            syn::FnArg::Receiver(syn::Receiver {
193                reference: None, ..
194            }) => {
195                method_type = MethodType::ByValue;
196                None
197            }
198            syn::FnArg::Typed(syn::PatType { pat, .. }) => {
199                if let syn::Pat::Ident(syn::PatIdent { ident, .. }) = &**pat {
200                    Some(ident.to_owned())
201                } else {
202                    // All non-ident fn args are replaced in `identify_signature_arguments`.
203                    unreachable!()
204                }
205            }
206        })
207        .collect();
208    let args = {
209        let mut args = syn::punctuated::Punctuated::new();
210        new_args.iter().for_each(|arg| {
211            args.push(syn::parse_str(arg.to_string().as_str()).unwrap());
212        });
213        args
214    };
215    (method_type, args)
216}
217
218/// Creates a method call that can be used in the match arms of all non-static method
219/// implementations.
220fn create_trait_fn_call(
221    trait_method: &syn::TraitItemFn,
222    trait_generics: &syn::TypeGenerics,
223    trait_name: &syn::Ident,
224) -> syn::Expr {
225    let trait_args = trait_method.to_owned().sig.inputs;
226    let (method_type, mut args) = extract_fn_args(trait_args);
227
228    // Insert FIELDNAME at the beginning of the argument list for UCFS-style method calling
229    let explicit_self_arg = syn::Ident::new(FIELDNAME, trait_method.span());
230    args.insert(0, plain_identifier_expr(explicit_self_arg));
231
232    let mut call = syn::Expr::from(syn::ExprCall {
233        attrs: vec![],
234        func: {
235            if let MethodType::Static = method_type {
236                // Trait calls can be created when the inner type is known, like this:
237                //
238                // syn::parse_quote! { #type::#trait_method_name }
239                //
240                // However, without a concrete enum to match on, it's impossible to tell
241                // which variant to call.
242                unimplemented!(
243                    "Static methods cannot be enum_dispatched (no self argument to match on)"
244                );
245            } else {
246                let method_name = &trait_method.sig.ident;
247                let trait_turbofish = trait_generics.as_turbofish();
248
249                // It's not allowed to specify late bound lifetime arguments for a function call.
250                // Theoretically, it should be possible to determine from a function signature
251                // whether or not it has late bound lifetime arguments. In practice, it's very
252                // difficult, requiring recursive visitors over all the types in the signature and
253                // inference for elided lifetimes.
254                //
255                // Instead, it appears to be safe to strip out any lifetime arguments altogether.
256                let mut generics_without_lifetimes = trait_method.sig.generics.clone();
257                generics_without_lifetimes.params = generics_without_lifetimes
258                    .params
259                    .into_iter()
260                    .filter(|param| !matches!(param, syn::GenericParam::Lifetime(..)))
261                    .collect();
262                let method_type_generics = generics_without_lifetimes.split_for_impl().1;
263                let method_turbofish = method_type_generics.as_turbofish();
264
265                Box::new(
266                    syn::parse_quote! { #trait_name#trait_turbofish::#method_name#method_turbofish },
267                )
268            }
269        },
270        paren_token: Default::default(),
271        args,
272    });
273
274    if trait_method.sig.asyncness.is_some() {
275        call = syn::Expr::from(syn::ExprAwait {
276            attrs: Default::default(),
277            base: Box::new(call),
278            dot_token: Default::default(),
279            await_token: Default::default(),
280        });
281    }
282
283    call
284}
285
286/// Constructs a match expression that matches on all variants of the specified enum, creating a
287/// binding to their single field and calling the provided trait method on each.
288fn create_match_expr(
289    trait_method: &syn::TraitItemFn,
290    trait_generics: &syn::TypeGenerics,
291    trait_name: &syn::Ident,
292    enum_name: &syn::Ident,
293    enumvariants: &[&EnumDispatchVariant],
294) -> syn::Expr {
295    let trait_fn_call = create_trait_fn_call(trait_method, trait_generics, trait_name);
296
297    let is_self_return = if let syn::ReturnType::Type(_, returntype) = &trait_method.sig.output {
298        match returntype.as_ref() {
299            syn::Type::Path(p) => {
300                if let Some(i) = p.path.get_ident() {
301                    i.to_string() == "Self"
302                } else {
303                    false
304                }
305            }
306            _ => false,
307        }
308    } else {
309        false
310    };
311
312    // Creates a Vec containing a match arm for every enum variant
313    let match_arms = enumvariants
314        .iter()
315        .map(|variant| {
316            let mut call = trait_fn_call.to_owned();
317
318            if is_self_return {
319                let variant_type = &variant.ty;
320                let from_call: syn::ExprCall = syn::parse_quote! {
321                    <Self as ::core::convert::From::<#variant_type>>::from(#call)
322                };
323                call = syn::Expr::from(from_call);
324            }
325
326            let variant_name = &variant.ident;
327            let attrs = variant
328                .attrs
329                .iter()
330                .filter(use_attribute)
331                .cloned()
332                .collect::<Vec<_>>();
333            syn::Arm {
334                attrs,
335                pat: {
336                    let fieldname = syn::Ident::new(FIELDNAME, variant.span());
337                    syn::parse_quote! {#enum_name::#variant_name(#fieldname)}
338                },
339                guard: None,
340                fat_arrow_token: Default::default(),
341                body: Box::new(call),
342                comma: Some(Default::default()),
343            }
344        })
345        .collect();
346
347    // Creates the match expression
348    syn::Expr::from(syn::ExprMatch {
349        attrs: vec![],
350        match_token: Default::default(),
351        expr: Box::new(plain_identifier_expr(syn::Ident::new(
352            "self",
353            proc_macro2::Span::call_site(),
354        ))),
355        brace_token: Default::default(),
356        arms: match_arms,
357    })
358}
359
360/// Builds an implementation of the given trait function for the given enum type.
361fn create_trait_match(
362    trait_item: syn::TraitItem,
363    trait_generics: &syn::TypeGenerics,
364    trait_name: &syn::Ident,
365    enum_name: &syn::Ident,
366    enumvariants: &[&EnumDispatchVariant],
367) -> syn::ImplItem {
368    match trait_item {
369        syn::TraitItem::Fn(mut trait_method) => {
370            identify_signature_arguments(&mut trait_method.sig);
371
372            let match_expr = create_match_expr(
373                &trait_method,
374                trait_generics,
375                trait_name,
376                enum_name,
377                enumvariants,
378            );
379
380            let mut impl_attrs = trait_method.attrs.clone();
381            // Inline impls - #[inline] is never already specified in a trait method signature
382            impl_attrs.push(syn::Attribute {
383                pound_token: Default::default(),
384                style: syn::AttrStyle::Outer,
385                bracket_token: Default::default(),
386                meta: syn::Meta::Path(syn::parse_str("inline").unwrap()),
387            });
388
389            syn::ImplItem::Fn(syn::ImplItemFn {
390                attrs: impl_attrs,
391                vis: syn::Visibility::Inherited,
392                defaultness: None,
393                sig: trait_method.sig,
394                block: syn::Block {
395                    brace_token: Default::default(),
396                    stmts: vec![syn::Stmt::Expr(match_expr, None)],
397                },
398            })
399        }
400        _ => panic!("Unsupported trait item"),
401    }
402}
403
404/// All method arguments that appear in trait method signatures must be passed through to the
405/// underlying dispatched method calls, so they must have unique identifiers. That means we need to
406/// give names to wildcard arguments (`_`), tuple-style arguments, and a bunch of other argument
407/// types you never knew were valid Rust syntax.
408///
409/// Since there is no way to generate hygienic identifiers, we just use a special underscored
410/// string followed by an incrementing counter. We do this for *every* argument, including ones
411/// that are already named, in case somebody clever decides to name their arguments similarly.
412fn identify_signature_arguments(sig: &mut syn::Signature) {
413    let mut arg_counter = 0;
414
415    /// Generates a new argument identifier named `__enum_dispatch_arg_` followed by an
416    /// incrementing counter.
417    fn new_arg_ident(span: proc_macro2::Span, arg_counter: &mut usize) -> syn::Ident {
418        let ident = proc_macro2::Ident::new(&format!("__enum_dispatch_arg_{}", arg_counter), span);
419        *arg_counter += 1;
420        ident
421    }
422
423    sig.inputs.iter_mut().for_each(|arg| match arg {
424        syn::FnArg::Typed(ref mut pat_type) => {
425            let span = pat_type.span();
426            *pat_type.pat = match &*pat_type.pat {
427                syn::Pat::Ident(ref pat_ident) => syn::Pat::Ident(syn::PatIdent {
428                    ident: new_arg_ident(pat_ident.span(), &mut arg_counter),
429                    ..pat_ident.clone()
430                }),
431                // Some of these aren't valid Rust syntax, but why not support all of them anyways!
432                syn::Pat::Lit(syn::PatLit { attrs, .. })
433                | syn::Pat::Macro(syn::PatMacro { attrs, .. })
434                | syn::Pat::Or(syn::PatOr { attrs, .. })
435                | syn::Pat::Path(syn::PatPath { attrs, .. })
436                | syn::Pat::Range(syn::PatRange { attrs, .. })
437                | syn::Pat::Reference(syn::PatReference { attrs, .. })
438                | syn::Pat::Rest(syn::PatRest { attrs, .. })
439                | syn::Pat::Slice(syn::PatSlice { attrs, .. })
440                | syn::Pat::Struct(syn::PatStruct { attrs, .. })
441                | syn::Pat::Tuple(syn::PatTuple { attrs, .. })
442                | syn::Pat::TupleStruct(syn::PatTupleStruct { attrs, .. })
443                | syn::Pat::Type(syn::PatType { attrs, .. })
444                | syn::Pat::Const(syn::PatConst { attrs, .. })
445                | syn::Pat::Paren(syn::PatParen { attrs, .. })
446                | syn::Pat::Wild(syn::PatWild { attrs, .. }) => syn::Pat::Ident(syn::PatIdent {
447                    attrs: attrs.to_owned(),
448                    by_ref: None,
449                    mutability: None,
450                    ident: new_arg_ident(span, &mut arg_counter),
451                    subpat: None,
452                }),
453                // This can occur for `box foo` syntax, which is no longer supported by syn 2.0.
454                syn::Pat::Verbatim(_) => syn::Pat::Ident(syn::PatIdent {
455                    attrs: Default::default(),
456                    by_ref: None,
457                    mutability: None,
458                    ident: new_arg_ident(span, &mut arg_counter),
459                    subpat: None,
460                }),
461                _ => panic!("Unsupported argument type"),
462            }
463        }
464        // `self` arguments will never need to be renamed.
465        syn::FnArg::Receiver(..) => (),
466    });
467}