serde_derive/
bound.rs

1use crate::internals::ast::{Container, Data};
2use crate::internals::{attr, ungroup};
3use proc_macro2::Span;
4use std::collections::HashSet;
5use syn::punctuated::{Pair, Punctuated};
6use syn::Token;
7
8// Remove the default from every type parameter because in the generated impls
9// they look like associated types: "error: associated type bindings are not
10// allowed here".
11pub fn without_defaults(generics: &syn::Generics) -> syn::Generics {
12    syn::Generics {
13        params: generics
14            .params
15            .iter()
16            .map(|param| match param {
17                syn::GenericParam::Type(param) => syn::GenericParam::Type(syn::TypeParam {
18                    eq_token: None,
19                    default: None,
20                    ..param.clone()
21                }),
22                _ => param.clone(),
23            })
24            .collect(),
25        ..generics.clone()
26    }
27}
28
29pub fn with_where_predicates(
30    generics: &syn::Generics,
31    predicates: &[syn::WherePredicate],
32) -> syn::Generics {
33    let mut generics = generics.clone();
34    generics
35        .make_where_clause()
36        .predicates
37        .extend(predicates.iter().cloned());
38    generics
39}
40
41pub fn with_where_predicates_from_fields(
42    cont: &Container,
43    generics: &syn::Generics,
44    from_field: fn(&attr::Field) -> Option<&[syn::WherePredicate]>,
45) -> syn::Generics {
46    let predicates = cont
47        .data
48        .all_fields()
49        .filter_map(|field| from_field(&field.attrs))
50        .flat_map(<[syn::WherePredicate]>::to_vec);
51
52    let mut generics = generics.clone();
53    generics.make_where_clause().predicates.extend(predicates);
54    generics
55}
56
57pub fn with_where_predicates_from_variants(
58    cont: &Container,
59    generics: &syn::Generics,
60    from_variant: fn(&attr::Variant) -> Option<&[syn::WherePredicate]>,
61) -> syn::Generics {
62    let variants = match &cont.data {
63        Data::Enum(variants) => variants,
64        Data::Struct(_, _) => {
65            return generics.clone();
66        }
67    };
68
69    let predicates = variants
70        .iter()
71        .filter_map(|variant| from_variant(&variant.attrs))
72        .flat_map(<[syn::WherePredicate]>::to_vec);
73
74    let mut generics = generics.clone();
75    generics.make_where_clause().predicates.extend(predicates);
76    generics
77}
78
79// Puts the given bound on any generic type parameters that are used in fields
80// for which filter returns true.
81//
82// For example, the following struct needs the bound `A: Serialize, B:
83// Serialize`.
84//
85//     struct S<'b, A, B: 'b, C> {
86//         a: A,
87//         b: Option<&'b B>
88//         #[serde(skip_serializing)]
89//         c: C,
90//     }
91pub fn with_bound(
92    cont: &Container,
93    generics: &syn::Generics,
94    filter: fn(&attr::Field, Option<&attr::Variant>) -> bool,
95    bound: &syn::Path,
96) -> syn::Generics {
97    struct FindTyParams<'ast> {
98        // Set of all generic type parameters on the current struct (A, B, C in
99        // the example). Initialized up front.
100        all_type_params: HashSet<syn::Ident>,
101
102        // Set of generic type parameters used in fields for which filter
103        // returns true (A and B in the example). Filled in as the visitor sees
104        // them.
105        relevant_type_params: HashSet<syn::Ident>,
106
107        // Fields whose type is an associated type of one of the generic type
108        // parameters.
109        associated_type_usage: Vec<&'ast syn::TypePath>,
110    }
111
112    impl<'ast> FindTyParams<'ast> {
113        fn visit_field(&mut self, field: &'ast syn::Field) {
114            if let syn::Type::Path(ty) = ungroup(&field.ty) {
115                if let Some(Pair::Punctuated(t, _)) = ty.path.segments.pairs().next() {
116                    if self.all_type_params.contains(&t.ident) {
117                        self.associated_type_usage.push(ty);
118                    }
119                }
120            }
121            self.visit_type(&field.ty);
122        }
123
124        fn visit_path(&mut self, path: &'ast syn::Path) {
125            if let Some(seg) = path.segments.last() {
126                if seg.ident == "PhantomData" {
127                    // Hardcoded exception, because PhantomData<T> implements
128                    // Serialize and Deserialize whether or not T implements it.
129                    return;
130                }
131            }
132            if path.leading_colon.is_none() && path.segments.len() == 1 {
133                let id = &path.segments[0].ident;
134                if self.all_type_params.contains(id) {
135                    self.relevant_type_params.insert(id.clone());
136                }
137            }
138            for segment in &path.segments {
139                self.visit_path_segment(segment);
140            }
141        }
142
143        // Everything below is simply traversing the syntax tree.
144
145        fn visit_type(&mut self, ty: &'ast syn::Type) {
146            match ty {
147                #![cfg_attr(all(test, exhaustive), deny(non_exhaustive_omitted_patterns))]
148                syn::Type::Array(ty) => self.visit_type(&ty.elem),
149                syn::Type::BareFn(ty) => {
150                    for arg in &ty.inputs {
151                        self.visit_type(&arg.ty);
152                    }
153                    self.visit_return_type(&ty.output);
154                }
155                syn::Type::Group(ty) => self.visit_type(&ty.elem),
156                syn::Type::ImplTrait(ty) => {
157                    for bound in &ty.bounds {
158                        self.visit_type_param_bound(bound);
159                    }
160                }
161                syn::Type::Macro(ty) => self.visit_macro(&ty.mac),
162                syn::Type::Paren(ty) => self.visit_type(&ty.elem),
163                syn::Type::Path(ty) => {
164                    if let Some(qself) = &ty.qself {
165                        self.visit_type(&qself.ty);
166                    }
167                    self.visit_path(&ty.path);
168                }
169                syn::Type::Ptr(ty) => self.visit_type(&ty.elem),
170                syn::Type::Reference(ty) => self.visit_type(&ty.elem),
171                syn::Type::Slice(ty) => self.visit_type(&ty.elem),
172                syn::Type::TraitObject(ty) => {
173                    for bound in &ty.bounds {
174                        self.visit_type_param_bound(bound);
175                    }
176                }
177                syn::Type::Tuple(ty) => {
178                    for elem in &ty.elems {
179                        self.visit_type(elem);
180                    }
181                }
182
183                syn::Type::Infer(_) | syn::Type::Never(_) | syn::Type::Verbatim(_) => {}
184
185                _ => {}
186            }
187        }
188
189        fn visit_path_segment(&mut self, segment: &'ast syn::PathSegment) {
190            self.visit_path_arguments(&segment.arguments);
191        }
192
193        fn visit_path_arguments(&mut self, arguments: &'ast syn::PathArguments) {
194            match arguments {
195                syn::PathArguments::None => {}
196                syn::PathArguments::AngleBracketed(arguments) => {
197                    for arg in &arguments.args {
198                        match arg {
199                            #![cfg_attr(all(test, exhaustive), deny(non_exhaustive_omitted_patterns))]
200                            syn::GenericArgument::Type(arg) => self.visit_type(arg),
201                            syn::GenericArgument::AssocType(arg) => self.visit_type(&arg.ty),
202                            syn::GenericArgument::Lifetime(_)
203                            | syn::GenericArgument::Const(_)
204                            | syn::GenericArgument::AssocConst(_)
205                            | syn::GenericArgument::Constraint(_) => {}
206                            _ => {}
207                        }
208                    }
209                }
210                syn::PathArguments::Parenthesized(arguments) => {
211                    for argument in &arguments.inputs {
212                        self.visit_type(argument);
213                    }
214                    self.visit_return_type(&arguments.output);
215                }
216            }
217        }
218
219        fn visit_return_type(&mut self, return_type: &'ast syn::ReturnType) {
220            match return_type {
221                syn::ReturnType::Default => {}
222                syn::ReturnType::Type(_, output) => self.visit_type(output),
223            }
224        }
225
226        fn visit_type_param_bound(&mut self, bound: &'ast syn::TypeParamBound) {
227            match bound {
228                #![cfg_attr(all(test, exhaustive), deny(non_exhaustive_omitted_patterns))]
229                syn::TypeParamBound::Trait(bound) => self.visit_path(&bound.path),
230                syn::TypeParamBound::Lifetime(_) | syn::TypeParamBound::Verbatim(_) => {}
231                _ => {}
232            }
233        }
234
235        // Type parameter should not be considered used by a macro path.
236        //
237        //     struct TypeMacro<T> {
238        //         mac: T!(),
239        //         marker: PhantomData<T>,
240        //     }
241        fn visit_macro(&mut self, _mac: &'ast syn::Macro) {}
242    }
243
244    let all_type_params = generics
245        .type_params()
246        .map(|param| param.ident.clone())
247        .collect();
248
249    let mut visitor = FindTyParams {
250        all_type_params,
251        relevant_type_params: HashSet::new(),
252        associated_type_usage: Vec::new(),
253    };
254    match &cont.data {
255        Data::Enum(variants) => {
256            for variant in variants {
257                let relevant_fields = variant
258                    .fields
259                    .iter()
260                    .filter(|field| filter(&field.attrs, Some(&variant.attrs)));
261                for field in relevant_fields {
262                    visitor.visit_field(field.original);
263                }
264            }
265        }
266        Data::Struct(_, fields) => {
267            for field in fields.iter().filter(|field| filter(&field.attrs, None)) {
268                visitor.visit_field(field.original);
269            }
270        }
271    }
272
273    let relevant_type_params = visitor.relevant_type_params;
274    let associated_type_usage = visitor.associated_type_usage;
275    let new_predicates = generics
276        .type_params()
277        .map(|param| param.ident.clone())
278        .filter(|id| relevant_type_params.contains(id))
279        .map(|id| syn::TypePath {
280            qself: None,
281            path: id.into(),
282        })
283        .chain(associated_type_usage.into_iter().cloned())
284        .map(|bounded_ty| {
285            syn::WherePredicate::Type(syn::PredicateType {
286                lifetimes: None,
287                // the type parameter that is being bounded e.g. T
288                bounded_ty: syn::Type::Path(bounded_ty),
289                colon_token: <Token![:]>::default(),
290                // the bound e.g. Serialize
291                bounds: vec![syn::TypeParamBound::Trait(syn::TraitBound {
292                    paren_token: None,
293                    modifier: syn::TraitBoundModifier::None,
294                    lifetimes: None,
295                    path: bound.clone(),
296                })]
297                .into_iter()
298                .collect(),
299            })
300        });
301
302    let mut generics = generics.clone();
303    generics
304        .make_where_clause()
305        .predicates
306        .extend(new_predicates);
307    generics
308}
309
310pub fn with_self_bound(
311    cont: &Container,
312    generics: &syn::Generics,
313    bound: &syn::Path,
314) -> syn::Generics {
315    let mut generics = generics.clone();
316    generics
317        .make_where_clause()
318        .predicates
319        .push(syn::WherePredicate::Type(syn::PredicateType {
320            lifetimes: None,
321            // the type that is being bounded e.g. MyStruct<'a, T>
322            bounded_ty: type_of_item(cont),
323            colon_token: <Token![:]>::default(),
324            // the bound e.g. Default
325            bounds: vec![syn::TypeParamBound::Trait(syn::TraitBound {
326                paren_token: None,
327                modifier: syn::TraitBoundModifier::None,
328                lifetimes: None,
329                path: bound.clone(),
330            })]
331            .into_iter()
332            .collect(),
333        }));
334    generics
335}
336
337pub fn with_lifetime_bound(generics: &syn::Generics, lifetime: &str) -> syn::Generics {
338    let bound = syn::Lifetime::new(lifetime, Span::call_site());
339    let def = syn::LifetimeParam {
340        attrs: Vec::new(),
341        lifetime: bound.clone(),
342        colon_token: None,
343        bounds: Punctuated::new(),
344    };
345
346    let params = Some(syn::GenericParam::Lifetime(def))
347        .into_iter()
348        .chain(generics.params.iter().cloned().map(|mut param| {
349            match &mut param {
350                syn::GenericParam::Lifetime(param) => {
351                    param.bounds.push(bound.clone());
352                }
353                syn::GenericParam::Type(param) => {
354                    param
355                        .bounds
356                        .push(syn::TypeParamBound::Lifetime(bound.clone()));
357                }
358                syn::GenericParam::Const(_) => {}
359            }
360            param
361        }))
362        .collect();
363
364    syn::Generics {
365        params,
366        ..generics.clone()
367    }
368}
369
370fn type_of_item(cont: &Container) -> syn::Type {
371    syn::Type::Path(syn::TypePath {
372        qself: None,
373        path: syn::Path {
374            leading_colon: None,
375            segments: vec![syn::PathSegment {
376                ident: cont.ident.clone(),
377                arguments: syn::PathArguments::AngleBracketed(
378                    syn::AngleBracketedGenericArguments {
379                        colon2_token: None,
380                        lt_token: <Token![<]>::default(),
381                        args: cont
382                            .generics
383                            .params
384                            .iter()
385                            .map(|param| match param {
386                                syn::GenericParam::Type(param) => {
387                                    syn::GenericArgument::Type(syn::Type::Path(syn::TypePath {
388                                        qself: None,
389                                        path: param.ident.clone().into(),
390                                    }))
391                                }
392                                syn::GenericParam::Lifetime(param) => {
393                                    syn::GenericArgument::Lifetime(param.lifetime.clone())
394                                }
395                                syn::GenericParam::Const(_) => {
396                                    panic!("Serde does not support const generics yet");
397                                }
398                            })
399                            .collect(),
400                        gt_token: <Token![>]>::default(),
401                    },
402                ),
403            }]
404            .into_iter()
405            .collect(),
406        },
407    })
408}