1use crate::{
2 ast::{Container, Data, Field, Variant},
3 attr::WithAttr,
4};
5use std::collections::BTreeSet;
6use syn::{punctuated::Punctuated, Ident};
7
8pub fn find_trait_bounds<'a>(orig_generics: &'a syn::Generics, cont: &mut Container<'a>) {
12 if orig_generics.params.is_empty() {
13 return;
14 }
15
16 let all_type_params = orig_generics
17 .type_params()
18 .map(|param| ¶m.ident)
19 .collect();
20
21 assert!(cont.rename_type_params.is_subset(&all_type_params));
22
23 let mut visitor = FindTyParams {
24 all_type_params,
25 relevant_type_params: cont.rename_type_params.clone(),
26 type_params_for_bound: cont.rename_type_params.clone(),
27 };
28
29 let mut field_explicit_bounds = Vec::new();
30
31 if visitor.all_type_params.len() > visitor.relevant_type_params.len() {
32 match &cont.data {
33 Data::Enum(variants) => {
34 for variant in variants {
35 let relevant_fields = variant
36 .fields
37 .iter()
38 .filter(|field| needs_jsonschema_bound(field, Some(variant)));
39
40 for field in relevant_fields {
41 field_explicit_bounds.extend(field.serde_attrs.de_bound());
42 visitor.visit_field(field);
43 }
44 }
45 }
46 Data::Struct(_, fields) => {
47 let relevant_fields = fields
48 .iter()
49 .filter(|field| needs_jsonschema_bound(field, None));
50
51 for field in relevant_fields {
52 field_explicit_bounds.extend(field.serde_attrs.de_bound());
53 visitor.visit_field(field);
54 }
55 }
56 }
57 }
58
59 cont.relevant_type_params = visitor.relevant_type_params;
60
61 let where_clause = cont.generics.make_where_clause();
62
63 if let Some(bounds) = cont.serde_attrs.de_bound() {
64 where_clause.predicates.extend(bounds.iter().cloned());
65 } else {
66 where_clause
67 .predicates
68 .extend(visitor.type_params_for_bound.into_iter().map(|ty| {
69 syn::WherePredicate::Type(syn::PredicateType {
70 lifetimes: None,
71 bounded_ty: syn::Type::Path(syn::TypePath {
72 qself: None,
73 path: syn::Path {
74 leading_colon: None,
75 segments: Punctuated::from_iter([syn::PathSegment {
76 ident: (*ty).clone(),
77 arguments: syn::PathArguments::None,
78 }]),
79 },
80 }),
81 colon_token: <Token![:]>::default(),
82 bounds: Punctuated::from_iter([syn::TypeParamBound::Trait(syn::TraitBound {
83 paren_token: None,
84 modifier: syn::TraitBoundModifier::None,
85 lifetimes: None,
86 path: parse_quote!(schemars::JsonSchema),
87 })]),
88 })
89 }));
90 }
91
92 where_clause
93 .predicates
94 .extend(field_explicit_bounds.into_iter().flatten().cloned());
95}
96
97fn needs_jsonschema_bound(field: &Field, variant: Option<&Variant>) -> bool {
98 if let Some(variant) = variant {
99 if variant.serde_attrs.skip_deserializing() && variant.serde_attrs.skip_serializing() {
100 return false;
101 }
102 }
103
104 if field.serde_attrs.skip_deserializing() && field.serde_attrs.skip_serializing() {
105 return false;
106 }
107
108 true
109}
110
111struct FindTyParams<'ast> {
112 all_type_params: BTreeSet<&'ast Ident>,
113 relevant_type_params: BTreeSet<&'ast Ident>,
114 type_params_for_bound: BTreeSet<&'ast Ident>,
115}
116
117#[allow(clippy::single_match)]
118impl FindTyParams<'_> {
119 fn visit_field(&mut self, field: &Field) {
120 match &field.attrs.with {
121 Some(WithAttr::Type(ty)) => self.visit_type(field, ty),
122 Some(WithAttr::Function(_)) => {
123 }
125 None => self.visit_type(field, &field.original.ty),
126 }
127 }
128
129 fn visit_path(&mut self, field: &Field, path: &syn::Path) {
130 if let Some(seg) = path.segments.last() {
131 if seg.ident == "PhantomData" {
132 return;
135 }
136 }
137
138 if path.leading_colon.is_none() {
139 if let Some(first_segment) = path.segments.first() {
140 let id = &first_segment.ident;
141 if let Some(id) = self.all_type_params.get(id) {
142 self.relevant_type_params.insert(id);
143 if field.serde_attrs.de_bound().is_none() {
144 self.type_params_for_bound.insert(id);
145 }
146 }
147 }
148 }
149
150 for segment in &path.segments {
151 self.visit_path_segment(field, segment);
152 }
153 }
154
155 fn visit_type(&mut self, field: &Field, ty: &syn::Type) {
156 match ty {
157 syn::Type::Array(ty) => self.visit_type(field, &ty.elem),
158 syn::Type::BareFn(ty) => {
159 for arg in &ty.inputs {
160 self.visit_type(field, &arg.ty);
161 }
162 self.visit_return_type(field, &ty.output);
163 }
164 syn::Type::Group(ty) => self.visit_type(field, &ty.elem),
165 syn::Type::ImplTrait(ty) => {
166 for bound in &ty.bounds {
167 self.visit_type_param_bound(field, bound);
168 }
169 }
170 syn::Type::Macro(ty) => self.visit_macro(field, &ty.mac),
171 syn::Type::Paren(ty) => self.visit_type(field, &ty.elem),
172 syn::Type::Path(ty) => {
173 if let Some(qself) = &ty.qself {
174 self.visit_type(field, &qself.ty);
175 }
176 self.visit_path(field, &ty.path);
177 }
178 syn::Type::Ptr(ty) => self.visit_type(field, &ty.elem),
179 syn::Type::Reference(ty) => {
180 self.visit_type(field, &ty.elem);
181 }
182 syn::Type::Slice(ty) => self.visit_type(field, &ty.elem),
183 syn::Type::TraitObject(ty) => {
184 for bound in &ty.bounds {
185 self.visit_type_param_bound(field, bound);
186 }
187 }
188 syn::Type::Tuple(ty) => {
189 for elem in &ty.elems {
190 self.visit_type(field, elem);
191 }
192 }
193 _ => {}
194 }
195 }
196
197 fn visit_path_segment(&mut self, field: &Field, segment: &syn::PathSegment) {
198 self.visit_path_arguments(field, &segment.arguments);
199 }
200
201 fn visit_path_arguments(&mut self, field: &Field, arguments: &syn::PathArguments) {
202 match arguments {
203 syn::PathArguments::None => {}
204 syn::PathArguments::AngleBracketed(arguments) => {
205 for arg in &arguments.args {
206 match arg {
207 syn::GenericArgument::Type(arg) => self.visit_type(field, arg),
208 syn::GenericArgument::AssocType(arg) => self.visit_type(field, &arg.ty),
209 _ => {}
210 }
211 }
212 }
213 syn::PathArguments::Parenthesized(arguments) => {
214 for argument in &arguments.inputs {
215 self.visit_type(field, argument);
216 }
217 self.visit_return_type(field, &arguments.output);
218 }
219 }
220 }
221
222 fn visit_return_type(&mut self, field: &Field, return_type: &syn::ReturnType) {
223 match return_type {
224 syn::ReturnType::Default => {}
225 syn::ReturnType::Type(_, output) => self.visit_type(field, output),
226 }
227 }
228
229 fn visit_type_param_bound(&mut self, field: &Field, bound: &syn::TypeParamBound) {
230 match bound {
231 syn::TypeParamBound::Trait(bound) => self.visit_path(field, &bound.path),
232 _ => {}
233 }
234 }
235
236 #[allow(clippy::unused_self)]
243 fn visit_macro(&mut self, _field: &Field, _mac: &syn::Macro) {}
244}
245
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    /// Checks the inferred `where` clause and the relevant type parameters for an enum that
    /// mixes renamed, used, skipped, macro-wrapped and `PhantomData`-wrapped type parameters.
    #[test]
    fn test_enum_bounds() {
        let input = parse_quote! {
            #[schemars(rename = "MyEnum<{T}, {U}, {V}, {W}, {X}, {Y}, {{Z}}>")]
            pub enum MyEnum<'a, const LEN: usize, T, U, V, W, X, Y, Z>
            where
                X: Trait,
                Z: OtherTrait
            {
                A,
                B(),
                C(T),
                D(U, (i8, V, bool)),
                E {
                    a: W,
                    b: [&'a Option<Box<<X as Trait>::AssocType::Z>>; LEN],
                    c: Token![Z],
                    d: PhantomData<Z>,
                    #[serde(skip)]
                    e: Z,
                },
                #[serde(skip)]
                F(Z),
            }
        };

        let cont = Container::from_ast(&input).unwrap();

        // The original predicates come first, followed by one `JsonSchema` bound per type
        // parameter selected for a bound; `Z` gets none.
        let expected: syn::WhereClause = parse_quote!(
            where
                X: Trait,
                Z: OtherTrait,
                T: schemars::JsonSchema,
                U: schemars::JsonSchema,
                V: schemars::JsonSchema,
                W: schemars::JsonSchema,
                X: schemars::JsonSchema,
                Y: schemars::JsonSchema
        );
        assert_eq!(cont.generics.where_clause, Some(expected));

        let relevant_names: Vec<String> = cont
            .relevant_type_params
            .into_iter()
            .map(|ident| ident.to_string())
            .collect();
        assert_eq!(relevant_names, vec!["T", "U", "V", "W", "X", "Y"]);
    }
}