// neli_proc_macros/derive_frombytes.rs

1use proc_macro2::{Span, TokenStream as TokenStream2};
2use quote::quote;
3use syn::{
4    parse_str, AngleBracketedGenericArguments, Attribute, Fields, GenericArgument, GenericParam,
5    Ident, ItemStruct, LifetimeDef, PathArguments, Token, TraitBound, Type, TypeParamBound,
6};
7
8use crate::shared::{process_input, process_lifetime, process_size, StructInfo};
9
10fn add_lifetime(trt: &mut TraitBound, lt: &LifetimeDef) {
11    trt.path.segments.iter_mut().for_each(|elem| {
12        if elem.ident == parse_str::<Ident>("FromBytes").unwrap()
13            || elem.ident == parse_str::<Ident>("FromBytesWithInput").unwrap()
14        {
15            if let PathArguments::AngleBracketed(ref mut args) = elem.arguments {
16                args.args = std::iter::once(GenericArgument::Lifetime(lt.lifetime.clone()))
17                    .chain(args.args.clone())
18                    .collect();
19            } else if let PathArguments::None = elem.arguments {
20                elem.arguments = PathArguments::AngleBracketed(AngleBracketedGenericArguments {
21                    colon2_token: Some(Token![::](Span::call_site())),
22                    lt_token: Token![<](Span::call_site()),
23                    args: std::iter::once(GenericArgument::Lifetime(lt.lifetime.clone())).collect(),
24                    gt_token: Token![>](Span::call_site()),
25                });
26            }
27        }
28    });
29}
30
31fn process_attrs(lt: &LifetimeDef, field_type: Type, field_attrs: Vec<Attribute>) -> TokenStream2 {
32    let input = process_input(&field_attrs);
33    let size = process_size(&field_attrs)
34        .unwrap_or_else(|| parse_str("input").expect("input is a valid expression"));
35    match input {
36        Some(Some(input)) => quote! {
37            {
38                let input = #input;
39                log::trace!(
40                    "Deserializing field type {}",
41                    std::any::type_name::<#field_type>(),
42                );
43                let position = buffer.position() as usize;
44                log::trace!(
45                    "Buffer to be deserialized: {:?}",
46                    &buffer.get_ref()[position..position + #size],
47                );
48                let ok = <#field_type as neli::FromBytesWithInput<#lt>>::from_bytes_with_input(
49                    buffer,
50                    input,
51                )?;
52                log::trace!("Field deserialized: {:?}", ok);
53                ok
54            }
55        },
56        Some(None) => quote! {
57            {
58                log::trace!(
59                    "Deserializing field type {}",
60                    std::any::type_name::<#field_type>(),
61                );
62                let position = buffer.position() as usize;
63                log::trace!(
64                    "Buffer to be deserialized: {:?}",
65                    &buffer.get_ref()[position..position + #size],
66                );
67                let ok = <#field_type as neli::FromBytesWithInput<#lt>>::from_bytes_with_input(
68                    buffer,
69                    input,
70                )?;
71                log::trace!("Field deserialized: {:?}", ok);
72                ok
73            }
74        },
75        None => quote! {
76            {
77                log::trace!(
78                    "Deserializing field type {}",
79                    std::any::type_name::<#field_type>(),
80                );
81                let position = buffer.position() as usize;
82                log::trace!(
83                    "Buffer to be deserialized: {:?}",
84                    &buffer.get_ref()[position..position + <#field_type as neli::TypeSize>::type_size()],
85                );
86                let ok = <#field_type as neli::FromBytes<#lt>>::from_bytes(buffer)?;
87                log::trace!("Field deserialized: {:?}", ok);
88                ok
89            }
90        },
91    }
92}
93
94pub fn impl_frombytes_struct(
95    is: ItemStruct,
96    trt: &str,
97    method_name: &str,
98    input_type: Option<TokenStream2>,
99    input: Option<TokenStream2>,
100) -> TokenStream2 {
101    let is_named = matches!(is.fields, Fields::Named(_));
102
103    let info = StructInfo::from_item_struct(is, Some(trt), "from_bytes_bound", false);
104
105    let trt = Ident::new(trt, Span::call_site());
106    let method_name = Ident::new(method_name, Span::call_site());
107
108    let (
109        struct_name,
110        mut generics,
111        generics_without_bounds,
112        field_names,
113        field_types,
114        field_attrs,
115        padded,
116    ) = info.into_tuple();
117
118    let lt = process_lifetime(&mut generics);
119
120    if field_names.is_empty() {
121        return quote! {
122            impl#generics neli::#trt<#lt> for #struct_name#generics_without_bounds {
123                #input_type
124
125                fn #method_name(buffer: &mut std::io::Cursor<&#lt [u8]> #input) -> Result<Self, neli::err::DeError> {
126                    Ok(#struct_name)
127                }
128            }
129        };
130    }
131
132    let struct_expr = if is_named {
133        quote! {
134            #struct_name {
135                #( #field_names, )*
136            }
137        }
138    } else {
139        quote! {
140            #struct_name(
141                #( #field_names, )*
142            )
143        }
144    };
145
146    for generic in generics.params.iter_mut() {
147        if let GenericParam::Type(ref mut ty) = generic {
148            for bound in ty.bounds.iter_mut() {
149                if let TypeParamBound::Trait(ref mut trt) = bound {
150                    add_lifetime(trt, &lt);
151                }
152            }
153        }
154    }
155
156    let from_bytes_exprs = field_types
157        .into_iter()
158        .zip(field_attrs.into_iter())
159        .map(|(field_type, field_attrs)| process_attrs(&lt, field_type, field_attrs));
160
161    let padding = if padded {
162        quote! {
163            <#struct_name#generics_without_bounds as neli::FromBytes<#lt>>::strip(buffer)?;
164        }
165    } else {
166        TokenStream2::new()
167    };
168
169    quote! {
170        impl#generics neli::#trt<#lt> for #struct_name#generics_without_bounds {
171            #input_type
172
173            fn #method_name(buffer: &mut std::io::Cursor<&#lt [u8]> #input) -> Result<Self, neli::err::DeError> {
174                let pos = buffer.position();
175
176                let res = {
177                    let mut from_bytes_impl = || {
178                        log::trace!("Deserializing data type {}", stringify!(#struct_name));
179                        #(
180                            let #field_names = #from_bytes_exprs;
181                        )*
182                        #padding
183                        Ok(#struct_expr)
184                    };
185                    from_bytes_impl()
186                };
187
188                match res {
189                    Ok(res) => Ok(res),
190                    Err(e) => {
191                        buffer.set_position(pos);
192                        Err(e)
193                    },
194                }
195            }
196        }
197    }
198}