schemars_derive/
bound.rs

1use crate::{
2    ast::{Container, Data, Field, Variant},
3    attr::WithAttr,
4};
5use std::collections::BTreeSet;
6use syn::{punctuated::Punctuated, Ident};
7
8// This logic is heavily based on serde_derive:
9// https://github.com/serde-rs/serde/blob/a1ddb18c92f32d64b2ccaf31ddd776e56be34ba2/serde_derive/src/bound.rs#L91
10
11pub fn find_trait_bounds<'a>(orig_generics: &'a syn::Generics, cont: &mut Container<'a>) {
12    if orig_generics.params.is_empty() {
13        return;
14    }
15
16    let all_type_params = orig_generics
17        .type_params()
18        .map(|param| &param.ident)
19        .collect();
20
21    assert!(cont.rename_type_params.is_subset(&all_type_params));
22
23    let mut visitor = FindTyParams {
24        all_type_params,
25        relevant_type_params: cont.rename_type_params.clone(),
26        type_params_for_bound: cont.rename_type_params.clone(),
27    };
28
29    let mut field_explicit_bounds = Vec::new();
30
31    if visitor.all_type_params.len() > visitor.relevant_type_params.len() {
32        match &cont.data {
33            Data::Enum(variants) => {
34                for variant in variants {
35                    let relevant_fields = variant
36                        .fields
37                        .iter()
38                        .filter(|field| needs_jsonschema_bound(field, Some(variant)));
39
40                    for field in relevant_fields {
41                        field_explicit_bounds.extend(field.serde_attrs.de_bound());
42                        visitor.visit_field(field);
43                    }
44                }
45            }
46            Data::Struct(_, fields) => {
47                let relevant_fields = fields
48                    .iter()
49                    .filter(|field| needs_jsonschema_bound(field, None));
50
51                for field in relevant_fields {
52                    field_explicit_bounds.extend(field.serde_attrs.de_bound());
53                    visitor.visit_field(field);
54                }
55            }
56        }
57    }
58
59    cont.relevant_type_params = visitor.relevant_type_params;
60
61    let where_clause = cont.generics.make_where_clause();
62
63    if let Some(bounds) = cont.serde_attrs.de_bound() {
64        where_clause.predicates.extend(bounds.iter().cloned());
65    } else {
66        where_clause
67            .predicates
68            .extend(visitor.type_params_for_bound.into_iter().map(|ty| {
69                syn::WherePredicate::Type(syn::PredicateType {
70                    lifetimes: None,
71                    bounded_ty: syn::Type::Path(syn::TypePath {
72                        qself: None,
73                        path: syn::Path {
74                            leading_colon: None,
75                            segments: Punctuated::from_iter([syn::PathSegment {
76                                ident: (*ty).clone(),
77                                arguments: syn::PathArguments::None,
78                            }]),
79                        },
80                    }),
81                    colon_token: <Token![:]>::default(),
82                    bounds: Punctuated::from_iter([syn::TypeParamBound::Trait(syn::TraitBound {
83                        paren_token: None,
84                        modifier: syn::TraitBoundModifier::None,
85                        lifetimes: None,
86                        path: parse_quote!(schemars::JsonSchema),
87                    })]),
88                })
89            }));
90    }
91
92    where_clause
93        .predicates
94        .extend(field_explicit_bounds.into_iter().flatten().cloned());
95}
96
97fn needs_jsonschema_bound(field: &Field, variant: Option<&Variant>) -> bool {
98    if let Some(variant) = variant {
99        if variant.serde_attrs.skip_deserializing() && variant.serde_attrs.skip_serializing() {
100            return false;
101        }
102    }
103
104    if field.serde_attrs.skip_deserializing() && field.serde_attrs.skip_serializing() {
105        return false;
106    }
107
108    true
109}
110
/// Visitor state used to discover which of a container's type parameters
/// are used by the types of its fields.
struct FindTyParams<'ast> {
    /// Every type parameter declared on the container; only these identifiers
    /// are ever recorded in the other two sets.
    all_type_params: BTreeSet<&'ast Ident>,
    /// Type parameters the generated schema depends on. It is seeded with the
    /// parameters named by `rename`, then extended while visiting fields.
    relevant_type_params: BTreeSet<&'ast Ident>,
    /// Type parameters that should receive an inferred `JsonSchema` bound.
    /// This is a subset of `relevant_type_params`: it omits parameters reached
    /// only through fields that carry their own explicit bound.
    type_params_for_bound: BTreeSet<&'ast Ident>,
}
116
117#[allow(clippy::single_match)]
118impl FindTyParams<'_> {
119    fn visit_field(&mut self, field: &Field) {
120        match &field.attrs.with {
121            Some(WithAttr::Type(ty)) => self.visit_type(field, ty),
122            Some(WithAttr::Function(_)) => {
123                // `schema_with` function type params may or may not implement `JsonSchema`
124            }
125            None => self.visit_type(field, &field.original.ty),
126        }
127    }
128
129    fn visit_path(&mut self, field: &Field, path: &syn::Path) {
130        if let Some(seg) = path.segments.last() {
131            if seg.ident == "PhantomData" {
132                // Hardcoded exception, because PhantomData<T> implements
133                // JsonSchema whether or not T implements it.
134                return;
135            }
136        }
137
138        if path.leading_colon.is_none() {
139            if let Some(first_segment) = path.segments.first() {
140                let id = &first_segment.ident;
141                if let Some(id) = self.all_type_params.get(id) {
142                    self.relevant_type_params.insert(id);
143                    if field.serde_attrs.de_bound().is_none() {
144                        self.type_params_for_bound.insert(id);
145                    }
146                }
147            }
148        }
149
150        for segment in &path.segments {
151            self.visit_path_segment(field, segment);
152        }
153    }
154
155    fn visit_type(&mut self, field: &Field, ty: &syn::Type) {
156        match ty {
157            syn::Type::Array(ty) => self.visit_type(field, &ty.elem),
158            syn::Type::BareFn(ty) => {
159                for arg in &ty.inputs {
160                    self.visit_type(field, &arg.ty);
161                }
162                self.visit_return_type(field, &ty.output);
163            }
164            syn::Type::Group(ty) => self.visit_type(field, &ty.elem),
165            syn::Type::ImplTrait(ty) => {
166                for bound in &ty.bounds {
167                    self.visit_type_param_bound(field, bound);
168                }
169            }
170            syn::Type::Macro(ty) => self.visit_macro(field, &ty.mac),
171            syn::Type::Paren(ty) => self.visit_type(field, &ty.elem),
172            syn::Type::Path(ty) => {
173                if let Some(qself) = &ty.qself {
174                    self.visit_type(field, &qself.ty);
175                }
176                self.visit_path(field, &ty.path);
177            }
178            syn::Type::Ptr(ty) => self.visit_type(field, &ty.elem),
179            syn::Type::Reference(ty) => {
180                self.visit_type(field, &ty.elem);
181            }
182            syn::Type::Slice(ty) => self.visit_type(field, &ty.elem),
183            syn::Type::TraitObject(ty) => {
184                for bound in &ty.bounds {
185                    self.visit_type_param_bound(field, bound);
186                }
187            }
188            syn::Type::Tuple(ty) => {
189                for elem in &ty.elems {
190                    self.visit_type(field, elem);
191                }
192            }
193            _ => {}
194        }
195    }
196
197    fn visit_path_segment(&mut self, field: &Field, segment: &syn::PathSegment) {
198        self.visit_path_arguments(field, &segment.arguments);
199    }
200
201    fn visit_path_arguments(&mut self, field: &Field, arguments: &syn::PathArguments) {
202        match arguments {
203            syn::PathArguments::None => {}
204            syn::PathArguments::AngleBracketed(arguments) => {
205                for arg in &arguments.args {
206                    match arg {
207                        syn::GenericArgument::Type(arg) => self.visit_type(field, arg),
208                        syn::GenericArgument::AssocType(arg) => self.visit_type(field, &arg.ty),
209                        _ => {}
210                    }
211                }
212            }
213            syn::PathArguments::Parenthesized(arguments) => {
214                for argument in &arguments.inputs {
215                    self.visit_type(field, argument);
216                }
217                self.visit_return_type(field, &arguments.output);
218            }
219        }
220    }
221
222    fn visit_return_type(&mut self, field: &Field, return_type: &syn::ReturnType) {
223        match return_type {
224            syn::ReturnType::Default => {}
225            syn::ReturnType::Type(_, output) => self.visit_type(field, output),
226        }
227    }
228
229    fn visit_type_param_bound(&mut self, field: &Field, bound: &syn::TypeParamBound) {
230        match bound {
231            syn::TypeParamBound::Trait(bound) => self.visit_path(field, &bound.path),
232            _ => {}
233        }
234    }
235
236    // Type parameter should not be considered used by a macro path.
237    //
238    //     struct TypeMacro<T> {
239    //         mac: T!(),
240    //         marker: PhantomData<T>,
241    //     }
242    #[allow(clippy::unused_self)]
243    fn visit_macro(&mut self, _field: &Field, _mac: &syn::Macro) {}
244}
245
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    /// Every type parameter except `Z` should get a `JsonSchema` bound and be
    /// reported as relevant.
    ///
    /// `Z` is escaped in the `rename` template (`{{Z}}`). It appears only in
    /// positions that must not count:
    /// - a non-leading segment of a qualified path;
    /// - a macro;
    /// - `PhantomData`;
    /// - a skipped field;
    /// - a skipped variant.
    ///
    /// The user-written `Z: OtherTrait` predicate must still be preserved.
    #[test]
    fn test_enum_bounds() {
        let input = parse_quote! {
            #[schemars(rename = "MyEnum<{T}, {U}, {V}, {W}, {X}, {Y}, {{Z}}>")]
            pub enum MyEnum<'a, const LEN: usize, T, U, V, W, X, Y, Z>
            where
                X: Trait,
                Z: OtherTrait
            {
                A,
                B(),
                C(T),
                D(U, (i8, V, bool)),
                E {
                    a: W,
                    b: [&'a Option<Box<<X as Trait>::AssocType::Z>>; LEN],
                    c: Token![Z],
                    d: PhantomData<Z>,
                    #[serde(skip)]
                    e: Z,
                },
                #[serde(skip)]
                F(Z),
            }
        };

        let cont = Container::from_ast(&input).unwrap();

        let expected_where_clause: syn::WhereClause = parse_quote!(
            where
                X: Trait,
                Z: OtherTrait,
                T: schemars::JsonSchema,
                U: schemars::JsonSchema,
                V: schemars::JsonSchema,
                W: schemars::JsonSchema,
                X: schemars::JsonSchema,
                Y: schemars::JsonSchema
        );
        assert_eq!(cont.generics.where_clause, Some(expected_where_clause));

        let relevant_names: Vec<String> = cont
            .relevant_type_params
            .iter()
            .map(|param| param.to_string())
            .collect();
        assert_eq!(relevant_names, ["T", "U", "V", "W", "X", "Y"]);
    }
}
299}