schemars_derive/attr/
schemars_to_serde.rs

1use quote::ToTokens;
2use serde_derive_internals::Ctxt;
3use std::collections::HashSet;
4use syn::parse::Parser;
5use syn::{Attribute, Data, Field, Meta, Variant};
6
7use super::get_meta_items;
8
/// Attribute keywords, valid inside `#[serde(...)]` / `#[schemars(...)]`, whose parsing is
/// delegated to `serde_derive_internals`.
///
/// The order of the entries carries no meaning for the membership checks in this file.
pub(crate) static SERDE_KEYWORDS: &[&str] = &[
    "rename",
    "rename_all",
    "deny_unknown_fields",
    "tag",
    "content",
    "untagged",
    "default",
    "skip",
    "skip_serializing",
    "skip_serializing_if",
    "skip_deserializing",
    "flatten",
    "remote",
    "transparent",
    // `bound` is a special case: it is stripped from serde attributes, so it only takes
    // effect when written inside a schemars attribute.
    "bound",
    // `with` / `serialize_with` are special cases too: they are handed to serde, but never
    // copied from schemars attributes into serde attributes. Keeping the serde attribute's
    // own `serialize_with` lets us decide whether a field's default value should be
    // serialized. The `with` value on schemars/serde attributes is also inspected (e.g. to
    // support deriving JsonSchema on remote types), but that is parsed by this crate rather
    // than by serde_derive_internals.
    "serialize_with",
    "with",
];
34
35// If a struct/variant/field has any #[schemars] attributes, then create copies of them
36// as #[serde] attributes so that serde_derive_internals will parse them for us.
37pub fn process_serde_attrs(input: &mut syn::DeriveInput) -> syn::Result<()> {
38    let ctxt = Ctxt::new();
39    process_attrs(&ctxt, &mut input.attrs);
40    match input.data {
41        Data::Struct(ref mut s) => process_serde_field_attrs(&ctxt, s.fields.iter_mut()),
42        Data::Enum(ref mut e) => process_serde_variant_attrs(&ctxt, e.variants.iter_mut()),
43        Data::Union(ref mut u) => process_serde_field_attrs(&ctxt, u.fields.named.iter_mut()),
44    };
45
46    ctxt.check()
47}
48
49fn process_serde_variant_attrs<'a>(ctxt: &Ctxt, variants: impl Iterator<Item = &'a mut Variant>) {
50    for v in variants {
51        process_attrs(ctxt, &mut v.attrs);
52        process_serde_field_attrs(ctxt, v.fields.iter_mut());
53    }
54}
55
56fn process_serde_field_attrs<'a>(ctxt: &Ctxt, fields: impl Iterator<Item = &'a mut Field>) {
57    for f in fields {
58        process_attrs(ctxt, &mut f.attrs);
59    }
60}
61
62fn process_attrs(ctxt: &Ctxt, attrs: &mut Vec<Attribute>) {
63    // Remove #[serde(...)] attributes (some may be re-added later)
64    let (serde_attrs, other_attrs): (Vec<_>, Vec<_>) =
65        attrs.drain(..).partition(|at| at.path().is_ident("serde"));
66    *attrs = other_attrs;
67
68    // Copy appropriate #[schemars(...)] attributes to #[serde(...)] attributes
69    let (mut serde_meta, mut schemars_meta_names): (Vec<_>, HashSet<_>) =
70        get_meta_items(attrs, "schemars", ctxt, false)
71            .into_iter()
72            .filter_map(|meta| {
73                let keyword = get_meta_ident(&meta)?;
74                if SERDE_KEYWORDS.contains(&keyword.as_ref()) && !keyword.ends_with("with") {
75                    Some((meta, keyword))
76                } else {
77                    None
78                }
79            })
80            .unzip();
81
82    if schemars_meta_names.contains("skip") {
83        schemars_meta_names.insert("skip_serializing".to_string());
84        schemars_meta_names.insert("skip_deserializing".to_string());
85    }
86
87    // Re-add #[serde(...)] attributes that weren't overridden by #[schemars(...)] attributes
88    for meta in get_meta_items(&serde_attrs, "serde", ctxt, false) {
89        if let Some(i) = get_meta_ident(&meta) {
90            if !schemars_meta_names.contains(&i)
91                && SERDE_KEYWORDS.contains(&i.as_ref())
92                && i != "bound"
93            {
94                serde_meta.push(meta);
95            }
96        }
97    }
98
99    if !serde_meta.is_empty() {
100        let new_serde_attr = quote! {
101            #[serde(#(#serde_meta),*)]
102        };
103
104        let parser = Attribute::parse_outer;
105        match parser.parse2(new_serde_attr) {
106            Ok(ref mut parsed) => attrs.append(parsed),
107            Err(e) => ctxt.error_spanned_by(to_tokens(attrs), e),
108        }
109    }
110}
111
112fn to_tokens(attrs: &[Attribute]) -> impl ToTokens {
113    let mut tokens = proc_macro2::TokenStream::new();
114    for attr in attrs {
115        attr.to_tokens(&mut tokens);
116    }
117    tokens
118}
119
120fn get_meta_ident(meta: &Meta) -> Option<String> {
121    meta.path().get_ident().map(|i| i.to_string())
122}
123
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;
    use syn::DeriveInput;

    /// Runs `process_serde_attrs` over a struct mixing serde, schemars and unrelated
    /// attributes, and compares the rewritten input against the expected attribute layout.
    #[test]
    fn test_process_serde_attrs() {
        // What the attributes should look like once processing has finished.
        let expected: DeriveInput = parse_quote! {
            #[schemars(rename = "overriden", another_unknown_word)]
            #[misc]
            #[serde(rename = "overriden", rename_all = "camelCase", default)]
            struct MyStruct {
                #[doc = r" blah blah blah"]
                #[serde(skip_serializing_if = "some_fn")]
                field1: i32,
                #[schemars(with = "with", bound = "bound")]
                #[serde(bound = "bound", serialize_with = "se")]
                field2: i32,
                #[schemars(skip)]
                #[serde(skip)]
                field3: i32,
            }
        };

        // The derive input as a user would have written it.
        let mut actual: DeriveInput = parse_quote! {
            #[serde(rename(serialize = "ser_name"), rename_all = "camelCase")]
            #[serde(default, unknown_word)]
            #[schemars(rename = "overriden", another_unknown_word)]
            #[misc]
            struct MyStruct {
                /// blah blah blah
                #[serde(skip_serializing_if = "some_fn", bound = "removed")]
                field1: i32,
                #[serde(serialize_with = "se", deserialize_with = "de")]
                #[schemars(with = "with", bound = "bound")]
                field2: i32,
                #[schemars(skip)]
                #[serde(skip_serializing)]
                field3: i32,
            }
        };

        process_serde_attrs(&mut actual)
            .unwrap_or_else(|e| panic!("process_serde_attrs returned error: {}", e));

        assert_eq!(actual, expected);
    }
}