// neli_proc_macros/derive_frombytes.rs
1use proc_macro2::{Span, TokenStream as TokenStream2};
2use quote::quote;
3use syn::{
4 parse_str, AngleBracketedGenericArguments, Attribute, Fields, GenericArgument, GenericParam,
5 Ident, ItemStruct, LifetimeDef, PathArguments, Token, TraitBound, Type, TypeParamBound,
6};
7
8use crate::shared::{process_input, process_lifetime, process_size, StructInfo};
9
10fn add_lifetime(trt: &mut TraitBound, lt: &LifetimeDef) {
11 trt.path.segments.iter_mut().for_each(|elem| {
12 if elem.ident == parse_str::<Ident>("FromBytes").unwrap()
13 || elem.ident == parse_str::<Ident>("FromBytesWithInput").unwrap()
14 {
15 if let PathArguments::AngleBracketed(ref mut args) = elem.arguments {
16 args.args = std::iter::once(GenericArgument::Lifetime(lt.lifetime.clone()))
17 .chain(args.args.clone())
18 .collect();
19 } else if let PathArguments::None = elem.arguments {
20 elem.arguments = PathArguments::AngleBracketed(AngleBracketedGenericArguments {
21 colon2_token: Some(Token)),
22 lt_token: Token),
23 args: std::iter::once(GenericArgument::Lifetime(lt.lifetime.clone())).collect(),
24 gt_token: Token),
25 });
26 }
27 }
28 });
29}
30
31fn process_attrs(lt: &LifetimeDef, field_type: Type, field_attrs: Vec<Attribute>) -> TokenStream2 {
32 let input = process_input(&field_attrs);
33 let size = process_size(&field_attrs)
34 .unwrap_or_else(|| parse_str("input").expect("input is a valid expression"));
35 match input {
36 Some(Some(input)) => quote! {
37 {
38 let input = #input;
39 log::trace!(
40 "Deserializing field type {}",
41 std::any::type_name::<#field_type>(),
42 );
43 let position = buffer.position() as usize;
44 log::trace!(
45 "Buffer to be deserialized: {:?}",
46 &buffer.get_ref()[position..position + #size],
47 );
48 let ok = <#field_type as neli::FromBytesWithInput<#lt>>::from_bytes_with_input(
49 buffer,
50 input,
51 )?;
52 log::trace!("Field deserialized: {:?}", ok);
53 ok
54 }
55 },
56 Some(None) => quote! {
57 {
58 log::trace!(
59 "Deserializing field type {}",
60 std::any::type_name::<#field_type>(),
61 );
62 let position = buffer.position() as usize;
63 log::trace!(
64 "Buffer to be deserialized: {:?}",
65 &buffer.get_ref()[position..position + #size],
66 );
67 let ok = <#field_type as neli::FromBytesWithInput<#lt>>::from_bytes_with_input(
68 buffer,
69 input,
70 )?;
71 log::trace!("Field deserialized: {:?}", ok);
72 ok
73 }
74 },
75 None => quote! {
76 {
77 log::trace!(
78 "Deserializing field type {}",
79 std::any::type_name::<#field_type>(),
80 );
81 let position = buffer.position() as usize;
82 log::trace!(
83 "Buffer to be deserialized: {:?}",
84 &buffer.get_ref()[position..position + <#field_type as neli::TypeSize>::type_size()],
85 );
86 let ok = <#field_type as neli::FromBytes<#lt>>::from_bytes(buffer)?;
87 log::trace!("Field deserialized: {:?}", ok);
88 ok
89 }
90 },
91 }
92}
93
94pub fn impl_frombytes_struct(
95 is: ItemStruct,
96 trt: &str,
97 method_name: &str,
98 input_type: Option<TokenStream2>,
99 input: Option<TokenStream2>,
100) -> TokenStream2 {
101 let is_named = matches!(is.fields, Fields::Named(_));
102
103 let info = StructInfo::from_item_struct(is, Some(trt), "from_bytes_bound", false);
104
105 let trt = Ident::new(trt, Span::call_site());
106 let method_name = Ident::new(method_name, Span::call_site());
107
108 let (
109 struct_name,
110 mut generics,
111 generics_without_bounds,
112 field_names,
113 field_types,
114 field_attrs,
115 padded,
116 ) = info.into_tuple();
117
118 let lt = process_lifetime(&mut generics);
119
120 if field_names.is_empty() {
121 return quote! {
122 impl#generics neli::#trt<#lt> for #struct_name#generics_without_bounds {
123 #input_type
124
125 fn #method_name(buffer: &mut std::io::Cursor<&#lt [u8]> #input) -> Result<Self, neli::err::DeError> {
126 Ok(#struct_name)
127 }
128 }
129 };
130 }
131
132 let struct_expr = if is_named {
133 quote! {
134 #struct_name {
135 #( #field_names, )*
136 }
137 }
138 } else {
139 quote! {
140 #struct_name(
141 #( #field_names, )*
142 )
143 }
144 };
145
146 for generic in generics.params.iter_mut() {
147 if let GenericParam::Type(ref mut ty) = generic {
148 for bound in ty.bounds.iter_mut() {
149 if let TypeParamBound::Trait(ref mut trt) = bound {
150 add_lifetime(trt, <);
151 }
152 }
153 }
154 }
155
156 let from_bytes_exprs = field_types
157 .into_iter()
158 .zip(field_attrs.into_iter())
159 .map(|(field_type, field_attrs)| process_attrs(<, field_type, field_attrs));
160
161 let padding = if padded {
162 quote! {
163 <#struct_name#generics_without_bounds as neli::FromBytes<#lt>>::strip(buffer)?;
164 }
165 } else {
166 TokenStream2::new()
167 };
168
169 quote! {
170 impl#generics neli::#trt<#lt> for #struct_name#generics_without_bounds {
171 #input_type
172
173 fn #method_name(buffer: &mut std::io::Cursor<&#lt [u8]> #input) -> Result<Self, neli::err::DeError> {
174 let pos = buffer.position();
175
176 let res = {
177 let mut from_bytes_impl = || {
178 log::trace!("Deserializing data type {}", stringify!(#struct_name));
179 #(
180 let #field_names = #from_bytes_exprs;
181 )*
182 #padding
183 Ok(#struct_expr)
184 };
185 from_bytes_impl()
186 };
187
188 match res {
189 Ok(res) => Ok(res),
190 Err(e) => {
191 buffer.set_position(pos);
192 Err(e)
193 },
194 }
195 }
196 }
197 }
198}