1use crate::cache;
4use quote::{quote, ToTokens};
5use syn::spanned::Spanned;
6
7use crate::enum_dispatch_item::EnumDispatchItem;
8use crate::enum_dispatch_variant::EnumDispatchVariant;
9use crate::syn_utils::plain_identifier_expr;
10
/// Identifier bound to a variant's payload in every generated match arm
/// (`Enum::Variant(inner) => ...`). The same name is inserted as the first
/// argument of the forwarded trait-method call, so it takes the place of `self`.
const FIELDNAME: &str = "inner";
15
16pub fn add_enum_impls(
19 enum_def: EnumDispatchItem,
20 traitdef: syn::ItemTrait,
21) -> proc_macro2::TokenStream {
22 let traitname = traitdef.ident;
23 let traitfns = traitdef.items;
24
25 let (generic_impl_constraints, enum_type_generics, where_clause) =
26 enum_def.generics.split_for_impl();
27 let (_, trait_type_generics, _) = traitdef.generics.split_for_impl();
28
29 let enumname = &enum_def.ident.to_owned();
30 let trait_impl = quote! {
31 impl #generic_impl_constraints #traitname #trait_type_generics for #enumname #enum_type_generics #where_clause {
32
33 }
34 };
35 let mut trait_impl: syn::ItemImpl = syn::parse(trait_impl.into()).unwrap();
36
37 trait_impl.unsafety = traitdef.unsafety;
38
39 let variants: Vec<&EnumDispatchVariant> = enum_def.variants.iter().collect();
40
41 for trait_fn in traitfns {
42 trait_impl.items.push(create_trait_match(
43 trait_fn,
44 &trait_type_generics,
45 &traitname,
46 &enum_def.ident,
47 &variants,
48 ));
49 }
50
51 let mut impls = proc_macro2::TokenStream::new();
52
53 if !cache::conversion_impls_def_by_enum(
55 &enum_def.ident,
56 enum_def.generics.type_params().count(),
57 ) {
58 let from_impls = generate_from_impls(&enum_def.ident, &variants, &enum_def.generics);
59 for from_impl in from_impls.iter() {
60 from_impl.to_tokens(&mut impls);
61 }
62
63 let try_into_impls =
64 generate_try_into_impls(&enum_def.ident, &variants, &trait_impl.generics);
65 for try_into_impl in try_into_impls.iter() {
66 try_into_impl.to_tokens(&mut impls);
67 }
68 cache::cache_enum_conversion_impls_defined(
69 enum_def.ident.clone(),
70 enum_def.generics.type_params().count(),
71 );
72 }
73
74 trait_impl.to_tokens(&mut impls);
75 impls
76}
77
78fn use_attribute(attr: &&syn::Attribute) -> bool {
81 attr.path().is_ident("cfg")
82}
83
84fn generate_from_impls(
86 enumname: &syn::Ident,
87 enumvariants: &[&EnumDispatchVariant],
88 generics: &syn::Generics,
89) -> Vec<syn::ItemImpl> {
90 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
91 enumvariants
92 .iter()
93 .map(|variant| {
94 let variant_name = &variant.ident;
95 let variant_type = &variant.ty;
96 let attributes = &variant.attrs.iter().filter(use_attribute).collect::<Vec<_>>();
97 let impl_block = quote! {
98 #(#attributes)*
99 impl #impl_generics ::core::convert::From<#variant_type> for #enumname #ty_generics #where_clause {
100 fn from(v: #variant_type) -> #enumname #ty_generics {
101 #enumname::#variant_name(v)
102 }
103 }
104 };
105 syn::parse(impl_block.into()).unwrap()
106 }).collect()
107}
108
109fn generate_try_into_impls(
111 enumname: &syn::Ident,
112 enumvariants: &[&EnumDispatchVariant],
113 generics: &syn::Generics,
114) -> Vec<syn::ItemImpl> {
115 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
116 enumvariants
117 .iter()
118 .enumerate()
119 .map(|(i, variant)| {
120 let variant_name = &variant.ident;
121 let variant_type = &variant.ty;
122 let attributes = &variant.attrs.iter().filter(use_attribute).collect::<Vec<_>>();
123
124 let other = enumvariants
129 .iter()
130 .enumerate()
131 .filter_map(
132 |(j, other)| if i != j { Some(other) } else { None });
133 let other_attributes = other
134 .clone()
135 .map(|other| {
136 let attrs = other.attrs.iter().filter(use_attribute);
137 quote! { #(#attrs)* }
138 });
139 let other_idents = other
140 .map(|other| other.ident.clone());
141 let from_str = other_idents.clone().map(|ident| ident.to_string());
142 let to_str = core::iter::repeat(variant_name.to_string());
143 let repeated = core::iter::repeat(&enumname);
144
145 let impl_block = quote! {
146 #(#attributes)*
147 impl #impl_generics ::core::convert::TryInto<#variant_type> for #enumname #ty_generics #where_clause {
148 type Error = &'static str;
149 fn try_into(self) -> ::core::result::Result<#variant_type, <Self as ::core::convert::TryInto<#variant_type>>::Error> {
150 match self {
151 #enumname::#variant_name(v) => {Ok(v)},
152 #( #other_attributes
153 #repeated::#other_idents(v) => {
154 Err(concat!("Tried to convert variant ",
155 #from_str, " to ", #to_str))} ),*
156 }
157 }
158 }
159 };
160 syn::parse(impl_block.into()).unwrap()
161 }).collect()
162}
163
/// How a trait method takes its receiver, as classified by `extract_fn_args`.
enum MethodType {
    /// No `self` receiver at all; such methods cannot be dispatched.
    Static,
    /// Receiver written with a leading `&` (`&self`, `&mut self`).
    ByReference,
    /// Any other `self` receiver (syn reports no `&` reference on it).
    ByValue,
}
173
174fn extract_fn_args(
177 trait_args: syn::punctuated::Punctuated<syn::FnArg, syn::token::Comma>,
178) -> (
179 MethodType,
180 syn::punctuated::Punctuated<syn::Expr, syn::token::Comma>,
181) {
182 let mut method_type = MethodType::Static;
183 let new_args: Vec<syn::Ident> = trait_args
184 .iter()
185 .filter_map(|arg| match arg {
186 syn::FnArg::Receiver(syn::Receiver {
187 reference: Some(_), ..
188 }) => {
189 method_type = MethodType::ByReference;
190 None
191 }
192 syn::FnArg::Receiver(syn::Receiver {
193 reference: None, ..
194 }) => {
195 method_type = MethodType::ByValue;
196 None
197 }
198 syn::FnArg::Typed(syn::PatType { pat, .. }) => {
199 if let syn::Pat::Ident(syn::PatIdent { ident, .. }) = &**pat {
200 Some(ident.to_owned())
201 } else {
202 unreachable!()
204 }
205 }
206 })
207 .collect();
208 let args = {
209 let mut args = syn::punctuated::Punctuated::new();
210 new_args.iter().for_each(|arg| {
211 args.push(syn::parse_str(arg.to_string().as_str()).unwrap());
212 });
213 args
214 };
215 (method_type, args)
216}
217
218fn create_trait_fn_call(
221 trait_method: &syn::TraitItemFn,
222 trait_generics: &syn::TypeGenerics,
223 trait_name: &syn::Ident,
224) -> syn::Expr {
225 let trait_args = trait_method.to_owned().sig.inputs;
226 let (method_type, mut args) = extract_fn_args(trait_args);
227
228 let explicit_self_arg = syn::Ident::new(FIELDNAME, trait_method.span());
230 args.insert(0, plain_identifier_expr(explicit_self_arg));
231
232 let mut call = syn::Expr::from(syn::ExprCall {
233 attrs: vec![],
234 func: {
235 if let MethodType::Static = method_type {
236 unimplemented!(
243 "Static methods cannot be enum_dispatched (no self argument to match on)"
244 );
245 } else {
246 let method_name = &trait_method.sig.ident;
247 let trait_turbofish = trait_generics.as_turbofish();
248
249 let mut generics_without_lifetimes = trait_method.sig.generics.clone();
257 generics_without_lifetimes.params = generics_without_lifetimes
258 .params
259 .into_iter()
260 .filter(|param| !matches!(param, syn::GenericParam::Lifetime(..)))
261 .collect();
262 let method_type_generics = generics_without_lifetimes.split_for_impl().1;
263 let method_turbofish = method_type_generics.as_turbofish();
264
265 Box::new(
266 syn::parse_quote! { #trait_name#trait_turbofish::#method_name#method_turbofish },
267 )
268 }
269 },
270 paren_token: Default::default(),
271 args,
272 });
273
274 if trait_method.sig.asyncness.is_some() {
275 call = syn::Expr::from(syn::ExprAwait {
276 attrs: Default::default(),
277 base: Box::new(call),
278 dot_token: Default::default(),
279 await_token: Default::default(),
280 });
281 }
282
283 call
284}
285
286fn create_match_expr(
289 trait_method: &syn::TraitItemFn,
290 trait_generics: &syn::TypeGenerics,
291 trait_name: &syn::Ident,
292 enum_name: &syn::Ident,
293 enumvariants: &[&EnumDispatchVariant],
294) -> syn::Expr {
295 let trait_fn_call = create_trait_fn_call(trait_method, trait_generics, trait_name);
296
297 let is_self_return = if let syn::ReturnType::Type(_, returntype) = &trait_method.sig.output {
298 match returntype.as_ref() {
299 syn::Type::Path(p) => {
300 if let Some(i) = p.path.get_ident() {
301 i.to_string() == "Self"
302 } else {
303 false
304 }
305 }
306 _ => false,
307 }
308 } else {
309 false
310 };
311
312 let match_arms = enumvariants
314 .iter()
315 .map(|variant| {
316 let mut call = trait_fn_call.to_owned();
317
318 if is_self_return {
319 let variant_type = &variant.ty;
320 let from_call: syn::ExprCall = syn::parse_quote! {
321 <Self as ::core::convert::From::<#variant_type>>::from(#call)
322 };
323 call = syn::Expr::from(from_call);
324 }
325
326 let variant_name = &variant.ident;
327 let attrs = variant
328 .attrs
329 .iter()
330 .filter(use_attribute)
331 .cloned()
332 .collect::<Vec<_>>();
333 syn::Arm {
334 attrs,
335 pat: {
336 let fieldname = syn::Ident::new(FIELDNAME, variant.span());
337 syn::parse_quote! {#enum_name::#variant_name(#fieldname)}
338 },
339 guard: None,
340 fat_arrow_token: Default::default(),
341 body: Box::new(call),
342 comma: Some(Default::default()),
343 }
344 })
345 .collect();
346
347 syn::Expr::from(syn::ExprMatch {
349 attrs: vec![],
350 match_token: Default::default(),
351 expr: Box::new(plain_identifier_expr(syn::Ident::new(
352 "self",
353 proc_macro2::Span::call_site(),
354 ))),
355 brace_token: Default::default(),
356 arms: match_arms,
357 })
358}
359
360fn create_trait_match(
362 trait_item: syn::TraitItem,
363 trait_generics: &syn::TypeGenerics,
364 trait_name: &syn::Ident,
365 enum_name: &syn::Ident,
366 enumvariants: &[&EnumDispatchVariant],
367) -> syn::ImplItem {
368 match trait_item {
369 syn::TraitItem::Fn(mut trait_method) => {
370 identify_signature_arguments(&mut trait_method.sig);
371
372 let match_expr = create_match_expr(
373 &trait_method,
374 trait_generics,
375 trait_name,
376 enum_name,
377 enumvariants,
378 );
379
380 let mut impl_attrs = trait_method.attrs.clone();
381 impl_attrs.push(syn::Attribute {
383 pound_token: Default::default(),
384 style: syn::AttrStyle::Outer,
385 bracket_token: Default::default(),
386 meta: syn::Meta::Path(syn::parse_str("inline").unwrap()),
387 });
388
389 syn::ImplItem::Fn(syn::ImplItemFn {
390 attrs: impl_attrs,
391 vis: syn::Visibility::Inherited,
392 defaultness: None,
393 sig: trait_method.sig,
394 block: syn::Block {
395 brace_token: Default::default(),
396 stmts: vec![syn::Stmt::Expr(match_expr, None)],
397 },
398 })
399 }
400 _ => panic!("Unsupported trait item"),
401 }
402}
403
404fn identify_signature_arguments(sig: &mut syn::Signature) {
413 let mut arg_counter = 0;
414
415 fn new_arg_ident(span: proc_macro2::Span, arg_counter: &mut usize) -> syn::Ident {
418 let ident = proc_macro2::Ident::new(&format!("__enum_dispatch_arg_{}", arg_counter), span);
419 *arg_counter += 1;
420 ident
421 }
422
423 sig.inputs.iter_mut().for_each(|arg| match arg {
424 syn::FnArg::Typed(ref mut pat_type) => {
425 let span = pat_type.span();
426 *pat_type.pat = match &*pat_type.pat {
427 syn::Pat::Ident(ref pat_ident) => syn::Pat::Ident(syn::PatIdent {
428 ident: new_arg_ident(pat_ident.span(), &mut arg_counter),
429 ..pat_ident.clone()
430 }),
431 syn::Pat::Lit(syn::PatLit { attrs, .. })
433 | syn::Pat::Macro(syn::PatMacro { attrs, .. })
434 | syn::Pat::Or(syn::PatOr { attrs, .. })
435 | syn::Pat::Path(syn::PatPath { attrs, .. })
436 | syn::Pat::Range(syn::PatRange { attrs, .. })
437 | syn::Pat::Reference(syn::PatReference { attrs, .. })
438 | syn::Pat::Rest(syn::PatRest { attrs, .. })
439 | syn::Pat::Slice(syn::PatSlice { attrs, .. })
440 | syn::Pat::Struct(syn::PatStruct { attrs, .. })
441 | syn::Pat::Tuple(syn::PatTuple { attrs, .. })
442 | syn::Pat::TupleStruct(syn::PatTupleStruct { attrs, .. })
443 | syn::Pat::Type(syn::PatType { attrs, .. })
444 | syn::Pat::Const(syn::PatConst { attrs, .. })
445 | syn::Pat::Paren(syn::PatParen { attrs, .. })
446 | syn::Pat::Wild(syn::PatWild { attrs, .. }) => syn::Pat::Ident(syn::PatIdent {
447 attrs: attrs.to_owned(),
448 by_ref: None,
449 mutability: None,
450 ident: new_arg_ident(span, &mut arg_counter),
451 subpat: None,
452 }),
453 syn::Pat::Verbatim(_) => syn::Pat::Ident(syn::PatIdent {
455 attrs: Default::default(),
456 by_ref: None,
457 mutability: None,
458 ident: new_arg_ident(span, &mut arg_counter),
459 subpat: None,
460 }),
461 _ => panic!("Unsupported argument type"),
462 }
463 }
464 syn::FnArg::Receiver(..) => (),
466 });
467}