schemars_derive/
schema_exprs.rs

1use std::collections::HashSet;
2
3use crate::{ast::*, attr::WithAttr, metadata::SchemaMetadata};
4use proc_macro2::{Span, TokenStream};
5use serde_derive_internals::ast::Style;
6use serde_derive_internals::attr::{self as serde_attr, Default as SerdeDefault, TagType};
7use syn::spanned::Spanned;
8
9pub fn expr_for_container(cont: &Container) -> TokenStream {
10    let mut schema_expr = match &cont.data {
11        Data::Struct(Style::Unit, _) => expr_for_unit_struct(),
12        Data::Struct(Style::Newtype, fields) => expr_for_newtype_struct(&fields[0]),
13        Data::Struct(Style::Tuple, fields) => expr_for_tuple_struct(fields),
14        Data::Struct(Style::Struct, fields) => expr_for_struct(
15            fields,
16            cont.serde_attrs.default(),
17            cont.serde_attrs.deny_unknown_fields(),
18        ),
19        Data::Enum(variants) => expr_for_enum(variants, &cont.serde_attrs),
20    };
21
22    cont.attrs.as_metadata().apply_to_schema(&mut schema_expr);
23    schema_expr
24}
25
26pub fn expr_for_repr(cont: &Container) -> Result<TokenStream, syn::Error> {
27    let repr_type = cont.attrs.repr.as_ref().ok_or_else(|| {
28        syn::Error::new(
29            Span::call_site(),
30            "JsonSchema_repr: missing #[repr(...)] attribute",
31        )
32    })?;
33
34    let variants = match &cont.data {
35        Data::Enum(variants) => variants,
36        _ => return Err(syn::Error::new(Span::call_site(), "oh no!")),
37    };
38
39    if let Some(non_unit_error) = variants.iter().find_map(|v| match v.style {
40        Style::Unit => None,
41        _ => Some(syn::Error::new(
42            v.original.span(),
43            "JsonSchema_repr: must be a unit variant",
44        )),
45    }) {
46        return Err(non_unit_error);
47    };
48
49    let enum_ident = &cont.ident;
50    let variant_idents = variants.iter().map(|v| &v.ident);
51
52    let mut schema_expr = schema_object(quote! {
53        instance_type: Some(schemars::schema::InstanceType::Integer.into()),
54        enum_values: Some(vec![#((#enum_ident::#variant_idents as #repr_type).into()),*]),
55    });
56
57    cont.attrs.as_metadata().apply_to_schema(&mut schema_expr);
58    Ok(schema_expr)
59}
60
61fn expr_for_field(field: &Field, allow_ref: bool) -> TokenStream {
62    let (ty, type_def) = type_for_field_schema(field);
63    let span = field.original.span();
64    let gen = quote!(gen);
65
66    let mut schema_expr = if field.validation_attrs.required() {
67        quote_spanned! {span=>
68            <#ty as schemars::JsonSchema>::_schemars_private_non_optional_json_schema(#gen)
69        }
70    } else if allow_ref {
71        quote_spanned! {span=>
72            #gen.subschema_for::<#ty>()
73        }
74    } else {
75        quote_spanned! {span=>
76            <#ty as schemars::JsonSchema>::json_schema(#gen)
77        }
78    };
79
80    prepend_type_def(type_def, &mut schema_expr);
81    field.validation_attrs.apply_to_schema(&mut schema_expr);
82
83    schema_expr
84}
85
86pub fn type_for_field_schema(field: &Field) -> (syn::Type, Option<TokenStream>) {
87    match &field.attrs.with {
88        None => (field.ty.to_owned(), None),
89        Some(with_attr) => type_for_schema(with_attr),
90    }
91}
92
93fn type_for_schema(with_attr: &WithAttr) -> (syn::Type, Option<TokenStream>) {
94    match with_attr {
95        WithAttr::Type(ty) => (ty.to_owned(), None),
96        WithAttr::Function(fun) => {
97            let ty_name = syn::Ident::new("_SchemarsSchemaWithFunction", Span::call_site());
98            let fn_name = fun.segments.last().unwrap().ident.to_string();
99
100            let type_def = quote_spanned! {fun.span()=>
101                struct #ty_name;
102
103                impl schemars::JsonSchema for #ty_name {
104                    fn is_referenceable() -> bool {
105                        false
106                    }
107
108                    fn schema_name() -> std::string::String {
109                        #fn_name.to_string()
110                    }
111
112                    fn schema_id() -> std::borrow::Cow<'static, str> {
113                        std::borrow::Cow::Borrowed(std::concat!(
114                            "_SchemarsSchemaWithFunction/",
115                            std::module_path!(),
116                            "/",
117                            #fn_name
118                        ))
119                    }
120
121                    fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema {
122                        #fun(gen)
123                    }
124                }
125            };
126
127            (parse_quote!(#ty_name), Some(type_def))
128        }
129    }
130}
131
132fn expr_for_enum(variants: &[Variant], cattrs: &serde_attr::Container) -> TokenStream {
133    let deny_unknown_fields = cattrs.deny_unknown_fields();
134    let variants = variants
135        .iter()
136        .filter(|v| !v.serde_attrs.skip_deserializing());
137
138    match cattrs.tag() {
139        TagType::External => expr_for_external_tagged_enum(variants, deny_unknown_fields),
140        TagType::None => expr_for_untagged_enum(variants, deny_unknown_fields),
141        TagType::Internal { tag } => {
142            expr_for_internal_tagged_enum(variants, tag, deny_unknown_fields)
143        }
144        TagType::Adjacent { tag, content } => {
145            expr_for_adjacent_tagged_enum(variants, tag, content, deny_unknown_fields)
146        }
147    }
148}
149
150fn expr_for_external_tagged_enum<'a>(
151    variants: impl Iterator<Item = &'a Variant<'a>>,
152    deny_unknown_fields: bool,
153) -> TokenStream {
154    let mut unique_names = HashSet::<&str>::new();
155    let mut count = 0;
156    let (unit_variants, complex_variants): (Vec<_>, Vec<_>) = variants
157        .inspect(|v| {
158            unique_names.insert(v.name());
159            count += 1;
160        })
161        .partition(|v| v.is_unit() && v.attrs.is_default());
162    let unit_names = unit_variants.iter().map(|v| v.name());
163    let unit_schema = schema_object(quote! {
164        instance_type: Some(schemars::schema::InstanceType::String.into()),
165        enum_values: Some(vec![#(#unit_names.into()),*]),
166    });
167
168    if complex_variants.is_empty() {
169        return unit_schema;
170    }
171
172    let mut schemas = Vec::new();
173    if !unit_variants.is_empty() {
174        schemas.push(unit_schema);
175    }
176
177    schemas.extend(complex_variants.into_iter().map(|variant| {
178        let name = variant.name();
179
180        let mut schema_expr = if variant.is_unit() && variant.attrs.with.is_none() {
181            quote! {
182                schemars::_private::new_unit_enum(#name)
183            }
184        } else {
185            let sub_schema = expr_for_untagged_enum_variant(variant, deny_unknown_fields);
186            quote! {
187                schemars::_private::new_externally_tagged_enum(#name, #sub_schema)
188            }
189        };
190
191        variant
192            .attrs
193            .as_metadata()
194            .apply_to_schema(&mut schema_expr);
195
196        schema_expr
197    }));
198
199    variant_subschemas(unique_names.len() == count, schemas)
200}
201
202fn expr_for_internal_tagged_enum<'a>(
203    variants: impl Iterator<Item = &'a Variant<'a>>,
204    tag_name: &str,
205    deny_unknown_fields: bool,
206) -> TokenStream {
207    let mut unique_names = HashSet::new();
208    let mut count = 0;
209    let variant_schemas = variants
210        .map(|variant| {
211            unique_names.insert(variant.name());
212            count += 1;
213
214            let name = variant.name();
215
216            let mut tag_schema = quote! {
217                schemars::_private::new_internally_tagged_enum(#tag_name, #name, #deny_unknown_fields)
218            };
219
220            variant.attrs.as_metadata().apply_to_schema(&mut tag_schema);
221
222            if let Some(variant_schema) =
223                expr_for_untagged_enum_variant_for_flatten(variant, deny_unknown_fields)
224            {
225                tag_schema.extend(quote!(.flatten(#variant_schema)))
226            }
227
228            tag_schema
229        })
230        .collect();
231
232    variant_subschemas(unique_names.len() == count, variant_schemas)
233}
234
235fn expr_for_untagged_enum<'a>(
236    variants: impl Iterator<Item = &'a Variant<'a>>,
237    deny_unknown_fields: bool,
238) -> TokenStream {
239    let schemas = variants
240        .map(|variant| {
241            let mut schema_expr = expr_for_untagged_enum_variant(variant, deny_unknown_fields);
242
243            variant
244                .attrs
245                .as_metadata()
246                .apply_to_schema(&mut schema_expr);
247
248            schema_expr
249        })
250        .collect();
251
252    // Untagged enums can easily have variants whose schemas overlap; rather
253    // that checking the exclusivity of each subschema we simply us `any_of`.
254    variant_subschemas(false, schemas)
255}
256
257fn expr_for_adjacent_tagged_enum<'a>(
258    variants: impl Iterator<Item = &'a Variant<'a>>,
259    tag_name: &str,
260    content_name: &str,
261    deny_unknown_fields: bool,
262) -> TokenStream {
263    let mut unique_names = HashSet::new();
264    let mut count = 0;
265    let schemas = variants
266        .map(|variant| {
267            unique_names.insert(variant.name());
268            count += 1;
269
270            let content_schema = if variant.is_unit() && variant.attrs.with.is_none() {
271                None
272            } else {
273                Some(expr_for_untagged_enum_variant(variant, deny_unknown_fields))
274            };
275
276            let (add_content_to_props, add_content_to_required) = content_schema
277                .map(|content_schema| {
278                    (
279                        quote!(props.insert(#content_name.to_owned(), #content_schema);),
280                        quote!(required.insert(#content_name.to_owned());),
281                    )
282                })
283                .unwrap_or_default();
284
285            let name = variant.name();
286            let tag_schema = schema_object(quote! {
287                instance_type: Some(schemars::schema::InstanceType::String.into()),
288                enum_values: Some(vec![#name.into()]),
289            });
290
291            let set_additional_properties = if deny_unknown_fields {
292                quote! {
293                    additional_properties: Some(Box::new(false.into())),
294                }
295            } else {
296                TokenStream::new()
297            };
298
299            let mut outer_schema = schema_object(quote! {
300                instance_type: Some(schemars::schema::InstanceType::Object.into()),
301                object: Some(Box::new(schemars::schema::ObjectValidation {
302                    properties: {
303                        let mut props = schemars::Map::new();
304                        props.insert(#tag_name.to_owned(), #tag_schema);
305                        #add_content_to_props
306                        props
307                    },
308                    required: {
309                        let mut required = schemars::Set::new();
310                        required.insert(#tag_name.to_owned());
311                        #add_content_to_required
312                        required
313                    },
314                    // As we're creating a "wrapper" object, we can honor the
315                    // disposition of deny_unknown_fields.
316                    #set_additional_properties
317                    ..Default::default()
318                })),
319            });
320
321            variant
322                .attrs
323                .as_metadata()
324                .apply_to_schema(&mut outer_schema);
325
326            outer_schema
327        })
328        .collect();
329
330    variant_subschemas(unique_names.len() == count, schemas)
331}
332
333/// Callers must determine if all subschemas are mutually exclusive. This can
334/// be done for most tagging regimes by checking that all tag names are unique.
335fn variant_subschemas(unique: bool, schemas: Vec<TokenStream>) -> TokenStream {
336    if unique {
337        schema_object(quote! {
338            subschemas: Some(Box::new(schemars::schema::SubschemaValidation {
339                one_of: Some(vec![#(#schemas),*]),
340                ..Default::default()
341            })),
342        })
343    } else {
344        schema_object(quote! {
345            subschemas: Some(Box::new(schemars::schema::SubschemaValidation {
346                any_of: Some(vec![#(#schemas),*]),
347                ..Default::default()
348            })),
349        })
350    }
351}
352
353fn expr_for_untagged_enum_variant(variant: &Variant, deny_unknown_fields: bool) -> TokenStream {
354    if let Some(with_attr) = &variant.attrs.with {
355        let (ty, type_def) = type_for_schema(with_attr);
356        let gen = quote!(gen);
357        let mut schema_expr = quote_spanned! {variant.original.span()=>
358            #gen.subschema_for::<#ty>()
359        };
360
361        prepend_type_def(type_def, &mut schema_expr);
362        return schema_expr;
363    }
364
365    match variant.style {
366        Style::Unit => expr_for_unit_struct(),
367        Style::Newtype => expr_for_field(&variant.fields[0], true),
368        Style::Tuple => expr_for_tuple_struct(&variant.fields),
369        Style::Struct => expr_for_struct(&variant.fields, &SerdeDefault::None, deny_unknown_fields),
370    }
371}
372
373fn expr_for_untagged_enum_variant_for_flatten(
374    variant: &Variant,
375    deny_unknown_fields: bool,
376) -> Option<TokenStream> {
377    if let Some(with_attr) = &variant.attrs.with {
378        let (ty, type_def) = type_for_schema(with_attr);
379        let gen = quote!(gen);
380        let mut schema_expr = quote_spanned! {variant.original.span()=>
381            <#ty as schemars::JsonSchema>::json_schema(#gen)
382        };
383
384        prepend_type_def(type_def, &mut schema_expr);
385        return Some(schema_expr);
386    }
387
388    Some(match variant.style {
389        Style::Unit => return None,
390        Style::Newtype => expr_for_field(&variant.fields[0], false),
391        Style::Tuple => expr_for_tuple_struct(&variant.fields),
392        Style::Struct => expr_for_struct(&variant.fields, &SerdeDefault::None, deny_unknown_fields),
393    })
394}
395
396fn expr_for_unit_struct() -> TokenStream {
397    quote! {
398        gen.subschema_for::<()>()
399    }
400}
401
402fn expr_for_newtype_struct(field: &Field) -> TokenStream {
403    expr_for_field(field, true)
404}
405
406fn expr_for_tuple_struct(fields: &[Field]) -> TokenStream {
407    let fields: Vec<_> = fields
408        .iter()
409        .filter(|f| !f.serde_attrs.skip_deserializing())
410        .map(|f| expr_for_field(f, true))
411        .collect();
412    let len = fields.len() as u32;
413
414    quote! {
415        schemars::schema::Schema::Object(
416            schemars::schema::SchemaObject {
417            instance_type: Some(schemars::schema::InstanceType::Array.into()),
418            array: Some(Box::new(schemars::schema::ArrayValidation {
419                items: Some(vec![#(#fields),*].into()),
420                max_items: Some(#len),
421                min_items: Some(#len),
422                ..Default::default()
423            })),
424            ..Default::default()
425        })
426    }
427}
428
429fn expr_for_struct(
430    fields: &[Field],
431    default: &SerdeDefault,
432    deny_unknown_fields: bool,
433) -> TokenStream {
434    let (flattened_fields, property_fields): (Vec<_>, Vec<_>) = fields
435        .iter()
436        .filter(|f| !f.serde_attrs.skip_deserializing() || !f.serde_attrs.skip_serializing())
437        .partition(|f| f.serde_attrs.flatten());
438
439    let set_container_default = match default {
440        SerdeDefault::None => None,
441        SerdeDefault::Default => Some(quote!(let container_default = Self::default();)),
442        SerdeDefault::Path(path) => Some(quote!(let container_default = #path();)),
443    };
444
445    let properties: Vec<_> = property_fields
446        .into_iter()
447        .map(|field| {
448            let name = field.name();
449            let default = field_default_expr(field, set_container_default.is_some());
450
451            let (ty, type_def) = type_for_field_schema(field);
452
453            let has_default = default.is_some();
454            let required = field.validation_attrs.required();
455
456            let metadata = SchemaMetadata {
457                read_only: field.serde_attrs.skip_deserializing(),
458                write_only: field.serde_attrs.skip_serializing(),
459                default,
460                ..field.attrs.as_metadata()
461            };
462
463            let gen = quote!(gen);
464            let mut schema_expr = if field.validation_attrs.required() {
465                quote_spanned! {ty.span()=>
466                    <#ty as schemars::JsonSchema>::_schemars_private_non_optional_json_schema(#gen)
467                }
468            } else {
469                quote_spanned! {ty.span()=>
470                    #gen.subschema_for::<#ty>()
471                }
472            };
473
474            metadata.apply_to_schema(&mut schema_expr);
475            field.validation_attrs.apply_to_schema(&mut schema_expr);
476
477            quote! {
478                {
479                    #type_def
480                    schemars::_private::insert_object_property::<#ty>(object_validation, #name, #has_default, #required, #schema_expr);
481                }
482            }
483        })
484        .collect();
485
486    let flattens: Vec<_> = flattened_fields
487        .into_iter()
488        .map(|field| {
489            let (ty, type_def) = type_for_field_schema(field);
490
491            let required = field.validation_attrs.required();
492
493            let args = quote!(gen, #required);
494            let mut schema_expr = quote_spanned! {ty.span()=>
495                schemars::_private::json_schema_for_flatten::<#ty>(#args)
496            };
497
498            prepend_type_def(type_def, &mut schema_expr);
499            schema_expr
500        })
501        .collect();
502
503    let set_additional_properties = if deny_unknown_fields {
504        quote! {
505            object_validation.additional_properties = Some(Box::new(false.into()));
506        }
507    } else {
508        TokenStream::new()
509    };
510    quote! {
511        {
512            #set_container_default
513            let mut schema_object = schemars::schema::SchemaObject {
514                instance_type: Some(schemars::schema::InstanceType::Object.into()),
515                ..Default::default()
516            };
517            let object_validation = schema_object.object();
518            #set_additional_properties
519            #(#properties)*
520            schemars::schema::Schema::Object(schema_object)
521            #(.flatten(#flattens))*
522        }
523    }
524}
525
526fn field_default_expr(field: &Field, container_has_default: bool) -> Option<TokenStream> {
527    let field_default = field.serde_attrs.default();
528    if field.serde_attrs.skip_serializing() || (field_default.is_none() && !container_has_default) {
529        return None;
530    }
531
532    let ty = field.ty;
533    let default_expr = match field_default {
534        SerdeDefault::None => {
535            let member = &field.member;
536            quote!(container_default.#member)
537        }
538        SerdeDefault::Default => quote!(<#ty>::default()),
539        SerdeDefault::Path(path) => quote!(#path()),
540    };
541
542    let default_expr = match field.serde_attrs.skip_serializing_if() {
543        Some(skip_if) => {
544            quote! {
545                {
546                    let default = #default_expr;
547                    if #skip_if(&default) {
548                        None
549                    } else {
550                        Some(default)
551                    }
552                }
553            }
554        }
555        None => quote!(Some(#default_expr)),
556    };
557
558    Some(if let Some(ser_with) = field.serde_attrs.serialize_with() {
559        quote! {
560            {
561                struct _SchemarsDefaultSerialize<T>(T);
562
563                impl serde::Serialize for _SchemarsDefaultSerialize<#ty>
564                {
565                    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
566                    where
567                        S: serde::Serializer
568                    {
569                        #ser_with(&self.0, serializer)
570                    }
571                }
572
573                #default_expr.map(|d| _SchemarsDefaultSerialize(d))
574            }
575        }
576    } else {
577        default_expr
578    })
579}
580
581fn schema_object(properties: TokenStream) -> TokenStream {
582    quote! {
583        schemars::schema::Schema::Object(
584            schemars::schema::SchemaObject {
585            #properties
586            ..Default::default()
587        })
588    }
589}
590
591fn prepend_type_def(type_def: Option<TokenStream>, schema_expr: &mut TokenStream) {
592    if let Some(type_def) = type_def {
593        *schema_expr = quote! {
594            {
595                #type_def
596                #schema_expr
597            }
598        }
599    }
600}