neli_proc_macros/
derive_frombytes.rs

1use proc_macro2::{Span, TokenStream as TokenStream2};
2use quote::quote;
3use syn::{Attribute, Fields, Ident, ItemStruct, Type, parse_str};
4
5use crate::shared::{StructInfo, process_input, process_size, process_skip_debug};
6
7fn process_attrs(field_type: Type, field_attrs: Vec<Attribute>) -> TokenStream2 {
8    let input = process_input(&field_attrs);
9    let skip_debug = process_skip_debug(&field_attrs);
10    let size = process_size(&field_attrs)
11        .unwrap_or_else(|| parse_str("input").expect("input is a valid expression"));
12    match (input, skip_debug) {
13        (Some(Some(input)), _) => quote! {
14            {
15                let input = #input;
16                log::trace!(
17                    "Deserializing field type {}",
18                    std::any::type_name::<#field_type>(),
19                );
20                let position = buffer.position() as usize;
21                let slice = buffer.get_ref().as_ref().get(position..position + #size);
22                if let Some(buf) = slice {
23                    log::trace!(
24                        "Buffer to be deserialized: {buf:?}",
25                    );
26                }
27                let ok = <#field_type as neli::FromBytesWithInput>::from_bytes_with_input(
28                    buffer,
29                    input,
30                )?;
31                log::trace!("Field deserialized: {:?}", ok);
32                ok
33            }
34        },
35        (Some(None), _) => quote! {
36            {
37                log::trace!(
38                    "Deserializing field type {}",
39                    std::any::type_name::<#field_type>(),
40                );
41                let position = buffer.position() as usize;
42                let slice = buffer.get_ref().as_ref().get(position..position + #size);
43                if let Some(buf) = slice {
44                    log::trace!(
45                        "Buffer to be deserialized: {buf:?}",
46                    );
47                }
48                let ok = <#field_type as neli::FromBytesWithInput>::from_bytes_with_input(
49                    buffer,
50                    input,
51                )?;
52                log::trace!("Field deserialized: {:?}", ok);
53                ok
54            }
55        },
56        (None, true) => quote! {
57            {
58                log::trace!(
59                    "Deserializing field type {}",
60                    std::any::type_name::<#field_type>(),
61                );
62                let position = buffer.position() as usize;
63                let ok = <#field_type as neli::FromBytes>::from_bytes(buffer)?;
64                log::trace!("Field deserialized: {:?}", ok);
65                ok
66            }
67        },
68        (None, false) => quote! {
69            {
70                log::trace!(
71                    "Deserializing field type {}",
72                    std::any::type_name::<#field_type>(),
73                );
74                let position = buffer.position() as usize;
75                let slice = buffer.get_ref()
76                    .as_ref()
77                    .get(position..position + <#field_type as neli::TypeSize>::type_size());
78                if let Some(buf) = slice {
79                    log::trace!(
80                        "Buffer to be deserialized: {buf:?}",
81                    );
82                }
83                let ok = <#field_type as neli::FromBytes>::from_bytes(buffer)?;
84                log::trace!("Field deserialized: {:?}", ok);
85                ok
86            }
87        },
88    }
89}
90
91pub fn impl_frombytes_struct(
92    is: ItemStruct,
93    trt: &str,
94    method_name: &str,
95    input_type: Option<TokenStream2>,
96    input: Option<TokenStream2>,
97) -> TokenStream2 {
98    let is_named = matches!(is.fields, Fields::Named(_));
99
100    let info = StructInfo::from_item_struct(is, Some(trt), "from_bytes_bound", false);
101
102    let trt = Ident::new(trt, Span::call_site());
103    let method_name = Ident::new(method_name, Span::call_site());
104
105    let (
106        struct_name,
107        generics,
108        generics_without_bounds,
109        field_names,
110        field_types,
111        field_attrs,
112        padded,
113    ) = info.into_tuple();
114
115    if field_names.is_empty() {
116        return quote! {
117            impl #generics neli::#trt for #struct_name #generics_without_bounds {
118                #input_type
119
120                fn #method_name(buffer: &mut std::io::Cursor<impl AsRef<[u8]>> #input) -> Result<Self, neli::err::DeError> {
121                    Ok(#struct_name)
122                }
123            }
124        };
125    }
126
127    let struct_expr = if is_named {
128        quote! {
129            #struct_name {
130                #( #field_names, )*
131            }
132        }
133    } else {
134        quote! {
135            #struct_name(
136                #( #field_names, )*
137            )
138        }
139    };
140
141    let from_bytes_exprs = field_types
142        .into_iter()
143        .zip(field_attrs)
144        .map(|(field_type, field_attrs)| process_attrs(field_type, field_attrs));
145
146    let padding = if padded {
147        quote! {
148            <#struct_name #generics_without_bounds as neli::FromBytes>::strip(buffer)?;
149        }
150    } else {
151        TokenStream2::new()
152    };
153
154    quote! {
155        impl #generics neli::#trt for #struct_name #generics_without_bounds {
156            #input_type
157
158            fn #method_name(buffer: &mut std::io::Cursor<impl AsRef<[u8]>> #input) -> Result<Self, neli::err::DeError> {
159                let pos = buffer.position();
160
161                let res = {
162                    let mut from_bytes_impl = || {
163                        log::trace!("Deserializing data type {}", stringify!(#struct_name));
164                        #(
165                            let #field_names = #from_bytes_exprs;
166                        )*
167                        #padding
168                        Ok(#struct_expr)
169                    };
170                    from_bytes_impl()
171                };
172
173                match res {
174                    Ok(res) => Ok(res),
175                    Err(e) => {
176                        buffer.set_position(pos);
177                        Err(e)
178                    },
179                }
180            }
181        }
182    }
183}